1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/deadness_analysis.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/string_view.h"
23 #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
24 #include "tensorflow/compiler/jit/xla_cluster_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/core/framework/tensor.pb.h"
27 #include "tensorflow/core/graph/algorithm.h"
28 #include "tensorflow/core/graph/control_flow.h"
29 #include "tensorflow/core/graph/graph_node_util.h"
30 #include "tensorflow/core/graph/tensor_id.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32
33 // ALGORITHM OVERVIEW
34 // ==================
35 //
36 // We map every output produced by each node in the TensorFlow graph (including
37 // control dependence) into an instance of the Predicate class. Instances of
38 // Predicate denote logical formulas and mapping a node `n` to a predicate
// `pred` implies that `n` is live whenever `pred` is true. Then we can detect
// mismatching liveness among the inputs to a node by comparing the predicates
// those inputs are mapped to. The core logic of this pass resides in creating
// the map from TensorFlow nodes to predicates.
43 //
44 //
45 // MAPPING NODES TO PREDICATES, MODULO CYCLES
46 // ------------------------------------------
47 //
48 // If we ignore cycles for a moment, computing predicates is fairly
49 // straightforward. We traverse the graph in a topological order, mapping each
50 // node to a predicate based on the predicates its inputs are mapped to. For
51 // instance a Merge(X, Y) node will be mapped to OR(PredicateFor(X),
// PredicateFor(Y)). Roughly speaking, we abstractly interpret each node on
53 // the "liveness" domain, where values in the domain represent if a tensor
54 // carries a dead signal or not.
55 //
56 //
57 // DEALING WITH CYCLES
58 // -------------------
59 //
60 // We map Merge nodes that are the target of a backedge to AndRecurrence
61 // instances. An AndRecurrence with start() = S and step() = X, printed as
62 // {S,&,X}, *roughly* represents the infinite list of predicates
// [S,S&X,S&X&X,S&X&X&X, ...]. So {S,&,X} can be used to represent the
// predicate for Merge in a graph like:
65 //
66 // Init
67 // |
68 // v
69 // Merge <-----------+
70 // | |
71 // v |
72 // Incr |
73 // | |
74 // v |
75 // Switch <- Cond |
76 // | |
77 // v (oidx: 1) |
78 // | |
79 // +---------------+
80 //
81 // Where S is the predicate for Init and X is the predicate that asserts that
82 // Cond is true. {S,&,X} states that Merge is live on the first "iteration" iff
83 // S is true, live on the second iteration iff "S&X" is true, live on the third
84 // iteration iff "S&X&X" is true etc. There is a subtlety here, S&X&X would
85 // normally be equivalent to S&X which isn't quite what we want to represent.
86 // Instead we want {S,&,X} to denote the infinite list [S, S&X,
87 // S&X&X',S&X&X'&X'', ...] where X, X', X'' are predicates that assert Cond is
88 // true on iteration 0, 1, 2 respectively. This is made more precise in the
89 // comment on the AndRecurrence class.
90 //
91 // The general algorithm that deals with cycles does two topological-order
92 // iterations over the graph. On the first iteration it assigns a symbolic
93 // predicate to merge nodes with backedges. On the second iteration it tries
94 // to pattern match the predicates for the backedges of these merges and infer
95 // an AndRecurrence for the merge. In other words, we do a data flow analysis
96 // where the data-flow lattice has two elements, Symbolic and NonSymbolic with
97 // Symbolic > NonSymbolic. The lattice has height = 2 so two iterations are
98 // sufficient to converge.
99 //
100 // We first do an optimistic analysis and, if it does not converge, we then fall
101 // back to a pessimistic analysis. The optimistic analysis assigns the same
102 // symbolic predicate to all the merge nodes whose preceding enter nodes have
103 // the same frame name on the first iteration. On the second iteration, if all
104 // the merge nodes are pattern matched into the same AndRecurrence predicate
105 // instance, the optimistic assignment of the same symbolic predicate is correct
106 // and the analyzed result is taken.
107 //
108 // Otherwise, if the optimistic analysis fails to converge, we then obtain the
109 // result by falling back to the pessimistic analysis which assigns a unique
110 // symbolic predicate to each merge on the first iteration. We still use
111 // symbolic predicates for merges for which we can't pattern match on the
112 // backedge predicate. This is conservatively correct.
113
114 namespace tensorflow {
115
116 namespace {
117
118 using se::port::StatusOr;
119
120 // Represents a logical predicate, used as described in the algorithm overview
121 // above.
122 class Predicate {
123 public:
124 enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol, kIntSymbol };
125
126 virtual string ToString() const = 0;
127
128 // An ID assigned to the Predicate at construction time. Conceptually like a
129 // pointer, except that it is stable across runs.
id() const130 int64 id() const { return id_; }
131
132 virtual absl::Span<Predicate* const> GetOperands() const = 0;
133
134 virtual Kind kind() const = 0;
~Predicate()135 virtual ~Predicate() {}
136
137 // Invokes func on p and on all of its operands recursively. Does not invoke
138 // `func` on the same Predicate instance twice. Aborts the search if `func`
139 // returns true.
140 template <typename FunctionTy>
141 static void Visit(Predicate* p, const FunctionTy& func);
142
143 protected:
Predicate(int64 id)144 explicit Predicate(int64 id) : id_(id) {}
145
146 private:
147 const int64 id_;
148
149 TF_DISALLOW_COPY_AND_ASSIGN(Predicate);
150 };
151
152 // Represents a logical conjunction of a set of predicates.
153 class AndPredicate : public Predicate {
154 public:
AndPredicate(int64 id,std::vector<Predicate * > operands)155 explicit AndPredicate(int64 id, std::vector<Predicate*> operands)
156 : Predicate(id), operands_(std::move(operands)) {}
157
ToString() const158 string ToString() const override {
159 if (operands().empty()) {
160 return "#true";
161 }
162
163 std::vector<string> operands_str;
164 std::transform(operands().begin(), operands().end(),
165 std::back_inserter(operands_str),
166 [](Predicate* pred) { return pred->ToString(); });
167
168 return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
169 }
170
kind() const171 Kind kind() const override { return Kind::kAnd; }
172
GetOperands() const173 absl::Span<Predicate* const> GetOperands() const override {
174 return operands_;
175 }
operands() const176 absl::Span<Predicate* const> operands() const { return operands_; }
177
178 private:
179 std::vector<Predicate*> operands_;
180 };
181
182 // Represents a logical disjunction of a set of predicates.
183 class OrPredicate : public Predicate {
184 public:
OrPredicate(int64 id,std::vector<Predicate * > operands)185 explicit OrPredicate(int64 id, std::vector<Predicate*> operands)
186 : Predicate(id), operands_(std::move(operands)) {}
187
ToString() const188 string ToString() const override {
189 if (operands().empty()) {
190 return "#false";
191 }
192
193 std::vector<string> operands_str;
194 std::transform(operands().begin(), operands().end(),
195 std::back_inserter(operands_str),
196 [](Predicate* pred) { return pred->ToString(); });
197
198 return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
199 }
200
kind() const201 Kind kind() const override { return Kind::kOr; }
GetOperands() const202 absl::Span<Predicate* const> GetOperands() const override {
203 return operands_;
204 }
operands() const205 absl::Span<Predicate* const> operands() const { return operands_; }
206
207 private:
208 std::vector<Predicate*> operands_;
209 };
210
211 // Represents a logical negation of a set of predicates.
212 class NotPredicate : public Predicate {
213 public:
NotPredicate(int64 id,Predicate * operand)214 explicit NotPredicate(int64 id, Predicate* operand)
215 : Predicate(id), operands_({operand}) {}
216
ToString() const217 string ToString() const override {
218 return absl::StrCat("~", operand()->ToString());
219 }
220
kind() const221 Kind kind() const override { return Kind::kNot; }
operand() const222 Predicate* operand() const { return operands_[0]; }
GetOperands() const223 absl::Span<Predicate* const> GetOperands() const override {
224 return operands_;
225 }
226
227 private:
228 std::array<Predicate*, 1> operands_;
229 };
230
231 // Represents the liveness of an induction variable. For users inside the loop
232 // this represents the "current" liveness of the induction variable. For users
233 // outside the loop it represents the "last" liveness of the induction variable.
234 //
235 // More concretely, an and recurrence {S,&,X}<loop> represents the liveness of V
236 // in the following graph:
237 //
238 // V = Merge(S', V_NextIt)
239 // V = Op(V, X')
240 // V_NextIt = NextIteration(V)
241 //
242 // where Predicate(S') = S and Predicate(X') = X.
243 //
244 // `X` may contain symbolic predicates and the operations corresponding to these
245 // symbolic predicates are either in frame `loop` or outside it. The symbols
246 // that are inside frame `loop` are loop variant (i.e. can have different
247 // liveness in each loop iteration) and the symbols that are outside frame
248 // `loop` are loop invariant (i.e. have the same liveness across all
249 // iterations).
250 class AndRecurrencePredicate : public Predicate {
251 public:
AndRecurrencePredicate(int64 id,Predicate * start,Predicate * step,std::vector<string> frame)252 explicit AndRecurrencePredicate(int64 id, Predicate* start, Predicate* step,
253 std::vector<string> frame)
254 : Predicate(id), operands_({start, step}), frame_(std::move(frame)) {}
255
start() const256 Predicate* start() const { return operands_[0]; }
step() const257 Predicate* step() const { return operands_[1]; }
frame() const258 absl::Span<const string> frame() const { return frame_; }
259
ToString() const260 string ToString() const override {
261 return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
262 "}<", absl::StrJoin(frame(), ";"), ">");
263 }
264
kind() const265 Kind kind() const override { return Kind::kAndRecurrence; }
266
GetOperands() const267 absl::Span<Predicate* const> GetOperands() const override {
268 return operands_;
269 }
270
271 private:
272 std::array<Predicate*, 2> operands_;
273 std::vector<string> frame_;
274 };
275
276 // Represents an uninterpreted symbol in a logical predicate.
277 //
278 // Two predicates are equivalent iff they are equivalent for all assignments to
279 // the symbols contained in them, i.e. predicates are forall qualified over
280 // symbols.
281 class SymbolPredicate : public Predicate {
282 public:
SymbolPredicate(int64 id,TensorId tensor_id,bool must_be_true)283 explicit SymbolPredicate(int64 id, TensorId tensor_id, bool must_be_true)
284 : Predicate(id),
285 tensor_id_(std::move(tensor_id)),
286 must_be_true_(must_be_true) {}
287
ToString() const288 string ToString() const override {
289 return must_be_true() ? absl::StrCat("*", tensor_id_.ToString())
290 : tensor_id_.ToString();
291 }
292
kind() const293 Kind kind() const override { return Kind::kSymbol; }
GetOperands() const294 absl::Span<Predicate* const> GetOperands() const override { return {}; }
295
296 // If `must_be_true()` is true this SymbolPredicate represents the proposition
297 // "tensor_id() is live and evaluates to true".
298 //
299 // If `must_be_true()` is false then this SymbolPredicate represents the
300 // proposition "tensor_id() is live (and may evaluate to any value)"
tensor_id() const301 TensorId tensor_id() const { return tensor_id_; }
must_be_true() const302 bool must_be_true() const { return must_be_true_; }
303
304 private:
305 TensorId tensor_id_;
306 bool must_be_true_;
307 };
308
309 // Represents an uninterpreted symbol in a logical predicate.
310 //
311 // Two predicates are equivalent iff they are equivalent for all assignments to
312 // the symbols contained in them, i.e. predicates are forall qualified over
313 // symbols.
314 class IntSymbolPredicate : public Predicate {
315 public:
IntSymbolPredicate(int64 id,TensorId tensor_id,absl::optional<int> must_have_value)316 explicit IntSymbolPredicate(int64 id, TensorId tensor_id,
317 absl::optional<int> must_have_value)
318 : Predicate(id),
319 tensor_id_(std::move(tensor_id)),
320 must_have_value_(must_have_value) {}
321
ToString() const322 string ToString() const override {
323 return must_have_value().has_value()
324 ? absl::StrCat(tensor_id_.ToString(), "=", *must_have_value_)
325 : tensor_id_.ToString();
326 }
327
kind() const328 Kind kind() const override { return Kind::kIntSymbol; }
GetOperands() const329 absl::Span<Predicate* const> GetOperands() const override { return {}; }
330
331 // If `must_have_value().has_value()` is true, then this IntSymbolPredicate
332 // represents the proposition "tensor_id() is live and evaluates to
333 // `*must_have_value()`".
334 //
335 // If `must_have_value().has_value()` is false, then this IntSymbolPredicate
336 // represents the proposition "tensor_id() is live (and may evaluate to any
337 // value)".
tensor_id() const338 TensorId tensor_id() const { return tensor_id_; }
must_have_value() const339 const absl::optional<int>& must_have_value() const {
340 return must_have_value_;
341 }
342
343 private:
344 TensorId tensor_id_;
345 absl::optional<int> must_have_value_;
346 };
347
348 template <typename FunctionTy>
Visit(Predicate * p,const FunctionTy & func)349 /*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) {
350 absl::flat_hash_set<Predicate*> visited;
351 std::vector<Predicate*> stack;
352
353 stack.push_back(p);
354 visited.insert(p);
355
356 while (!stack.empty()) {
357 Predicate* current = stack.back();
358 stack.pop_back();
359 bool done = func(current);
360 if (done) {
361 return;
362 }
363 for (Predicate* op : current->GetOperands()) {
364 if (visited.insert(op).second) {
365 stack.push_back(op);
366 }
367 }
368 }
369 }
370
371 // Creates and owns Predicate instances. Simplifies predicates as it creates
372 // them.
373 class PredicateFactory {
374 public:
MakeAndPredicate(absl::Span<Predicate * const> operands)375 Predicate* MakeAndPredicate(absl::Span<Predicate* const> operands) {
376 return MakeAndOrImpl(operands, /*is_and=*/true);
377 }
378
MakeOrPredicate(absl::Span<Predicate * const> operands)379 Predicate* MakeOrPredicate(absl::Span<Predicate* const> operands) {
380 return MakeAndOrImpl(operands, /*is_and=*/false);
381 }
382
MakeNotPredicate(Predicate * pred)383 Predicate* MakeNotPredicate(Predicate* pred) {
384 auto it = make_not_predicate_cache_.find(pred);
385 if (it != make_not_predicate_cache_.end()) {
386 return it->second;
387 }
388
389 Predicate* result = MakeNotPredicateImpl(pred);
390
391 bool insert_successful =
392 make_not_predicate_cache_.insert({pred, result}).second;
393 (void)insert_successful;
394 DCHECK(insert_successful);
395
396 return result;
397 }
398
MakeAndRecurrencePredicate(Predicate * start,Predicate * step,std::vector<string> frame)399 Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step,
400 std::vector<string> frame) {
401 SignatureForAndRec signature(start, step, std::move(frame));
402 auto it = interned_and_rec_instances_.find(signature);
403 if (it != interned_and_rec_instances_.end()) {
404 return it->second.get();
405 }
406
407 std::unique_ptr<Predicate> new_pred = Make<AndRecurrencePredicate>(
408 std::get<0>(signature), std::get<1>(signature), std::get<2>(signature));
409 Predicate* new_pred_ptr = new_pred.get();
410 bool inserted =
411 interned_and_rec_instances_.emplace(signature, std::move(new_pred))
412 .second;
413 (void)inserted;
414 DCHECK(inserted);
415 return new_pred_ptr;
416 }
417
MakeSymbolPredicate(Node * node,int output_idx,bool must_be_true,Predicate ** predicate)418 Status MakeSymbolPredicate(Node* node, int output_idx, bool must_be_true,
419 Predicate** predicate) {
420 TensorId tensor_id(node->name(), output_idx);
421
422 bool is_boolean_tensor =
423 BaseType(node->output_type(tensor_id.index())) == DT_BOOL;
424 TF_RET_CHECK(!must_be_true || is_boolean_tensor);
425
426 if (node->type_string() == "Const" && must_be_true) {
427 const TensorProto* proto = nullptr;
428 TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto));
429
430 Tensor tensor(proto->dtype());
431 TF_RET_CHECK(tensor.FromProto(*proto));
432
433 *predicate = tensor.scalar<bool>()() ? MakeTrue() : MakeFalse();
434 return Status::OK();
435 }
436
437 SignatureForSymbol signature = {tensor_id, must_be_true};
438 auto it = interned_symbol_instances_.find(signature);
439 if (it == interned_symbol_instances_.end()) {
440 std::unique_ptr<Predicate> new_pred =
441 Make<SymbolPredicate>(tensor_id, must_be_true);
442 Predicate* new_pred_ptr = new_pred.get();
443 interned_symbol_instances_.emplace(std::move(signature),
444 std::move(new_pred));
445 *predicate = new_pred_ptr;
446 } else {
447 *predicate = it->second.get();
448 }
449
450 return Status::OK();
451 }
452
MakeSymbolPredicate(Node * node,int output_idx,absl::optional<int> must_have_value,Predicate ** predicate)453 Status MakeSymbolPredicate(Node* node, int output_idx,
454 absl::optional<int> must_have_value,
455 Predicate** predicate) {
456 TensorId tensor_id(node->name(), output_idx);
457
458 TF_RET_CHECK(BaseType(node->output_type(tensor_id.index())) == DT_INT32);
459
460 if (must_have_value.has_value() && node->type_string() == "Const") {
461 const TensorProto* proto = nullptr;
462 TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto));
463
464 Tensor tensor(proto->dtype());
465 TF_RET_CHECK(tensor.FromProto(*proto));
466
467 *predicate = tensor.scalar<int32>()() == *must_have_value ? MakeTrue()
468 : MakeFalse();
469 return Status::OK();
470 }
471 SignatureForIntSymbol signature = {tensor_id, must_have_value};
472 auto it = interned_int_symbol_instances_.find(signature);
473 if (it == interned_int_symbol_instances_.end()) {
474 std::unique_ptr<Predicate> new_pred =
475 Make<IntSymbolPredicate>(tensor_id, must_have_value);
476 Predicate* new_pred_ptr = new_pred.get();
477 interned_int_symbol_instances_.emplace(std::move(signature),
478 std::move(new_pred));
479 *predicate = new_pred_ptr;
480 } else {
481 *predicate = it->second.get();
482 }
483
484 return Status::OK();
485 }
486
MakeTrue()487 Predicate* MakeTrue() { return MakeAndPredicate({}); }
MakeFalse()488 Predicate* MakeFalse() { return MakeOrPredicate({}); }
489
~PredicateFactory()490 ~PredicateFactory() {
491 DCHECK_EQ(stack_depth_, 0) << "Unnested IncrementStackDepth?";
492 }
493
494 private:
MakeNotPredicateImpl(Predicate * pred)495 Predicate* MakeNotPredicateImpl(Predicate* pred) {
496 IncrementStackDepth stack_frame(this);
497 if (!stack_frame.HasOverflowed()) {
498 if (Predicate* simplified = SimplifyUsingDeMorgan(pred)) {
499 return simplified;
500 }
501
502 // ~~A => A
503 if (auto* not_pred = dynamic_cast<NotPredicate*>(pred)) {
504 return not_pred->operand();
505 }
506 }
507
508 SignatureForNot signature = pred;
509 auto it = interned_not_instances_.find(signature);
510 if (it == interned_not_instances_.end()) {
511 std::unique_ptr<Predicate> new_pred = Make<NotPredicate>(pred);
512 Predicate* new_pred_ptr = new_pred.get();
513 interned_not_instances_.emplace(signature, std::move(new_pred));
514 return new_pred_ptr;
515 } else {
516 return it->second.get();
517 }
518 }
519
SimplifyUsingDeMorgan(Predicate * pred)520 Predicate* SimplifyUsingDeMorgan(Predicate* pred) {
521 // ~(A & B & C & ...) => ~A | ~B | ~C | ~...
522 // ~(A | B | C | ...) -> ~A & ~B & ~C & ~...
523 Predicate::Kind kind = pred->kind();
524
525 if (kind == Predicate::Kind::kAnd || kind == Predicate::Kind::kOr) {
526 std::vector<Predicate*> new_operands;
527 absl::c_transform(pred->GetOperands(), std::back_inserter(new_operands),
528 [&](Predicate* p) { return MakeNotPredicate(p); });
529 return kind == Predicate::Kind::kOr ? MakeAndPredicate(new_operands)
530 : MakeOrPredicate(new_operands);
531 }
532
533 return nullptr;
534 }
535
536 template <typename PredicateT, typename... Args>
Make(Args &&...args)537 std::unique_ptr<Predicate> Make(Args&&... args) {
538 // If we ever expose the Predicate class outside this .cc file then we may
539 // want to make this hard to misuse (by accidentally passing in an arbitrary
540 // integer to the Predicate constructor for instance).
541 return std::unique_ptr<PredicateT>(
542 new PredicateT(id_counter_++, std::forward<Args>(args)...));
543 }
544
545 Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
546 Predicate* MakeInternedAndOr(std::vector<Predicate*> simplified_ops,
547 Predicate::Kind pred_kind);
548
549 // Predicate instances are interned, meaning that there is only a single
550 // instance of a Predicate object with a given content. This makes checking
551 // for structural equality super-cheap -- we can just compare pointers.
552 //
553 // We intern predicates by maintaining a map from the content of a Predicate
554 // to the only instance of said predicate we allow to exist in the
555 // interned_and_or_instances_, interned_not_instances_ and
556 // interned_symbol_instances_ fields. These maps also double up as storage
557 // for the owning pointers to predicate instances.
558
559 using SignatureForAndOr =
560 std::pair<Predicate::Kind, absl::Span<Predicate* const>>;
561 using SignatureForNot = Predicate*;
562 using SignatureForAndRec =
563 std::tuple<Predicate*, Predicate*, std::vector<string>>;
564 using SignatureForSymbol = std::pair<SafeTensorId, bool>;
565 using SignatureForIntSymbol = std::pair<SafeTensorId, absl::optional<int32>>;
566
567 struct HashSignatureForAndOr {
operator ()tensorflow::__anonab73ef130111::PredicateFactory::HashSignatureForAndOr568 size_t operator()(const SignatureForAndOr& signature) const {
569 size_t hash = ::tensorflow::hash<Predicate::Kind>()(signature.first);
570 for (Predicate* p : signature.second) {
571 hash = Hash64Combine(hash, ::tensorflow::hash<Predicate*>()(p));
572 }
573 return hash;
574 }
575 };
576
577 struct HashSignatureForSymbol {
operator ()tensorflow::__anonab73ef130111::PredicateFactory::HashSignatureForSymbol578 size_t operator()(const SignatureForSymbol& signature) const {
579 return Hash64Combine(SafeTensorId::Hasher()(signature.first),
580 ::tensorflow::hash<bool>()(signature.second));
581 }
582 };
583
584 struct HashSignatureForIntSymbol {
operator ()tensorflow::__anonab73ef130111::PredicateFactory::HashSignatureForIntSymbol585 size_t operator()(const SignatureForIntSymbol& signature) const {
586 return Hash64Combine(
587 SafeTensorId::Hasher()(signature.first),
588 Hash64Combine(
589 ::tensorflow::hash<bool>()(signature.second.has_value()),
590 ::tensorflow::hash<int32>()(
591 signature.second.has_value() ? *signature.second : 0)));
592 }
593 };
594
595 // Used to limit recursion to avoid blowing up the stack and cap compile time.
596 class IncrementStackDepth {
597 public:
IncrementStackDepth(PredicateFactory * parent)598 explicit IncrementStackDepth(PredicateFactory* parent) : parent_(parent) {
599 parent_->stack_depth_++;
600 }
601
HasOverflowed() const602 bool HasOverflowed() const {
603 const int kMaxStackDepth = 8;
604 return parent_->stack_depth_ >= kMaxStackDepth;
605 }
606
~IncrementStackDepth()607 ~IncrementStackDepth() { parent_->stack_depth_--; }
608
609 private:
610 PredicateFactory* parent_;
611 };
612
613 // A cache for the MakeNotPredicate function.
614 //
615 // NB! This is *not* the same as `interned_not_instances_`.
616 // `interned_not_instances_` maps ensures pointer identity for `NotPredicate`
617 // instances, i.e., it ensures there at most one instance of Not(predicate)
618 // for any given predicate whereas `make_not_predicate_cache_` simply caches
619 // the result of the `MakeNotPredicate` function. The values in
620 // `interned_not_instances_` are always instance of `NotPredicate` whereas the
621 // values in `make_not_predicate_cache_` may not be (for instance it will map
622 // Not(Not(A)) to A).
623 absl::flat_hash_map<Predicate*, Predicate*> make_not_predicate_cache_;
624
625 absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>,
626 HashSignatureForAndOr>
627 interned_and_or_instances_;
628 absl::flat_hash_map<SignatureForNot, std::unique_ptr<Predicate>>
629 interned_not_instances_;
630 absl::flat_hash_map<SignatureForAndRec, std::unique_ptr<Predicate>>
631 interned_and_rec_instances_;
632 absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>,
633 HashSignatureForSymbol>
634 interned_symbol_instances_;
635 absl::flat_hash_map<SignatureForIntSymbol, std::unique_ptr<Predicate>,
636 HashSignatureForIntSymbol>
637 interned_int_symbol_instances_;
638 int64 id_counter_ = 0;
639 int stack_depth_ = 0;
640 };
641
MakeInternedAndOr(std::vector<Predicate * > simplified_ops,Predicate::Kind pred_kind)642 Predicate* PredicateFactory::MakeInternedAndOr(
643 std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) {
644 std::stable_sort(
645 simplified_ops.begin(), simplified_ops.end(),
646 [](Predicate* a, Predicate* b) { return a->id() < b->id(); });
647
648 auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
649 if (it != interned_and_or_instances_.end()) {
650 return it->second.get();
651 }
652
653 simplified_ops.shrink_to_fit();
654 // NB! Because we'll use a non-owning reference to simplified_ops in the
655 // key for interned_and_or_instances_ we need to be careful to std::move()
656 // it all the way through.
657 absl::Span<Predicate* const> operands_slice = simplified_ops;
658 std::unique_ptr<Predicate> new_pred =
659 pred_kind == Predicate::Kind::kAnd
660 ? Make<AndPredicate>(std::move(simplified_ops))
661 : Make<OrPredicate>(std::move(simplified_ops));
662
663 Predicate* new_pred_ptr = new_pred.get();
664 interned_and_or_instances_.emplace(
665 SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred));
666 return new_pred_ptr;
667 }
668
669 // Common code to create AndPredicate or OrPredicate instances.
MakeAndOrImpl(absl::Span<Predicate * const> operands,bool is_and)670 Predicate* PredicateFactory::MakeAndOrImpl(
671 absl::Span<Predicate* const> operands, bool is_and) {
672 Predicate::Kind pred_kind =
673 is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
674
675 IncrementStackDepth stack_frame(this);
676 if (stack_frame.HasOverflowed()) {
677 return MakeInternedAndOr(
678 std::vector<Predicate*>(operands.begin(), operands.end()), pred_kind);
679 }
680
681 Predicate::Kind other_pred_kind =
682 is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
683 absl::flat_hash_set<Predicate*> simplified_ops_set;
684 std::vector<Predicate*> simplified_ops;
685 for (Predicate* op : operands) {
686 // Simplify A&A => A and A|A => A.
687 if (!simplified_ops_set.insert(op).second) {
688 continue;
689 }
690
691 if (op->kind() == pred_kind) {
692 // "Inline" the operands of an inner And/Or into the parent And/Or.
693 for (Predicate* subop : op->GetOperands()) {
694 if (simplified_ops_set.insert(subop).second) {
695 simplified_ops.push_back(subop);
696 }
697 }
698 } else {
699 simplified_ops.push_back(op);
700 }
701 }
702
703 if (simplified_ops.size() == 1) {
704 return simplified_ops[0];
705 }
706
707 // Simplify "A&~A=>False" and "A|~A=>True".
708 absl::flat_hash_set<Predicate*> negated_ops;
709 for (Predicate* op : simplified_ops) {
710 if (negated_ops.count(op)) {
711 // Simple case:
712 //
713 // A & ~A & ... == False
714 // A | ~A | ... == True
715 return is_and ? MakeFalse() : MakeTrue();
716 }
717
718 Predicate* negated_op = MakeNotPredicate(op);
719 if (negated_op->kind() == pred_kind) {
720 // Slightly more complicated case:
721 //
722 // (~A | ~B | ~C) & A & B & C & ... ==
723 // ~(A & B & C) & (A & B & C) & ... == False
724 //
725 // (~A & ~B & ~C) | A | B | C | ... ==
726 // ~(A | B | C) | (A | B | C) | ... == True
727 if (absl::c_all_of(negated_op->GetOperands(), [&](Predicate* p) {
728 return simplified_ops_set.contains(p);
729 })) {
730 return is_and ? MakeFalse() : MakeTrue();
731 }
732 }
733 negated_ops.insert(negated_op);
734 }
735
736 // Simplify {S,&,X} & ~X & ... => S & ...
737 if (is_and) {
738 absl::flat_hash_set<Predicate*> to_remove;
739 std::vector<Predicate*> to_add;
740 for (Predicate* op : simplified_ops) {
741 if (op->kind() == Predicate::Kind::kAndRecurrence) {
742 auto* and_rec = static_cast<AndRecurrencePredicate*>(op);
743 if (negated_ops.contains(and_rec->step())) {
744 // Remove and_rec and ~X and insert S. Note that checking the
745 // existence of ~X through negated_ops is sufficient since it makes
746 // sure the predicate is in the input operands. It does not need to
747 // be in simplified_ops if it was already cancelled out.
748 to_remove.insert(and_rec);
749 to_remove.insert(MakeNotPredicate(and_rec->step()));
750 to_add.push_back(and_rec->start());
751 }
752 }
753 }
754 auto it = simplified_ops.begin();
755 while (it != simplified_ops.end()) {
756 if (to_remove.contains(*it)) {
757 it = simplified_ops.erase(it);
758 } else {
759 ++it;
760 }
761 }
762 simplified_ops.insert(simplified_ops.end(), to_add.begin(), to_add.end());
763 }
764
765 // If all ops contain the same subop, then factor it out thanks to the
766 // distributive property. Such as:
767 // - (A & B) | (A & C) | (A & D) => A & (B | C | D)
768 // - (A | B) & (A | C) & (A | D) => A | (B & C & D)
769 //
770 // First find any predicates contained in all subops.
771 std::vector<Predicate*> common_inner_operands;
772 absl::flat_hash_set<Predicate*> common_inner_operands_set;
773 for (Predicate* op : simplified_ops) {
774 if (op->kind() != other_pred_kind) {
775 common_inner_operands.clear();
776 break;
777 }
778
779 if (common_inner_operands.empty()) {
780 common_inner_operands.insert(common_inner_operands.end(),
781 op->GetOperands().begin(),
782 op->GetOperands().end());
783 } else {
784 common_inner_operands.clear();
785 absl::c_copy_if(op->GetOperands(),
786 std::back_inserter(common_inner_operands),
787 [&](Predicate* sub_op) {
788 return common_inner_operands_set.count(sub_op) == 1;
789 });
790 }
791 if (common_inner_operands.empty()) break;
792 common_inner_operands_set.clear();
793 common_inner_operands_set.insert(common_inner_operands.begin(),
794 common_inner_operands.end());
795 }
796
797 if (common_inner_operands.empty()) {
798 return MakeInternedAndOr(std::move(simplified_ops), pred_kind);
799 }
800
801 // For all predicates that can be factored out, remove them and recreate the
802 // subops.
803 std::vector<Predicate*> factored_ops;
804 for (Predicate* op : simplified_ops) {
805 std::vector<Predicate*> new_sub_op_ops;
806 absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops),
807 [&](Predicate* sub_op) {
808 return std::find(common_inner_operands.begin(),
809 common_inner_operands.end(),
810 sub_op) == common_inner_operands.end();
811 });
812 factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and));
813 }
814
815 Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and);
816 std::vector<Predicate*> outer_ops;
817 outer_ops.push_back(new_inner_op);
818 outer_ops.insert(outer_ops.end(), common_inner_operands.begin(),
819 common_inner_operands.end());
820 return MakeAndOrImpl(outer_ops, !is_and);
821 }
822
823 class DeadnessAnalysisImpl : public DeadnessAnalysis {
824 public:
DeadnessAnalysisImpl(const Graph * graph)825 explicit DeadnessAnalysisImpl(const Graph* graph)
826 : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
827
828 Status Populate(bool enable_optimistic);
829 Status PopulateFrame(absl::Span<Node* const> topo, bool use_optimistic_mode,
830 bool* success);
831 StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
832 Node* n, int oidx) const override;
833 void Print() const override;
834 absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
835 const;
836
837 private:
838 enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
839
840 Status GetInputPreds(Node* n, EdgeKind edge_kind,
841 std::vector<Predicate*>* result);
842
843 // Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th
844 // bit of `should_revisit` if `pred` is different from the current predicate
845 // for the `output_idx` output of `n`.
SetPredicate(Node * n,int output_idx,Predicate * pred,std::vector<bool> * should_revisit)846 void SetPredicate(Node* n, int output_idx, Predicate* pred,
847 std::vector<bool>* should_revisit) {
848 auto insert_result =
849 predicate_map_.insert({TensorId(n->name(), output_idx), pred});
850 if (!insert_result.second && insert_result.first->second != pred) {
851 VLOG(4) << "For " << n->name() << ":" << output_idx << " from "
852 << insert_result.first->second->ToString() << " "
853 << insert_result.first->second << " to " << pred->ToString()
854 << " " << pred;
855 insert_result.first->second = pred;
856 if (should_revisit != nullptr) {
857 for (const Edge* e : n->out_edges()) {
858 (*should_revisit)[e->dst()->id()] = true;
859 }
860 }
861 }
862 }
863
SetPredicate(Node * n,absl::Span<const int> output_idxs,Predicate * pred,std::vector<bool> * should_revisit)864 void SetPredicate(Node* n, absl::Span<const int> output_idxs, Predicate* pred,
865 std::vector<bool>* should_revisit) {
866 for (int output_idx : output_idxs) {
867 SetPredicate(n, output_idx, pred, should_revisit);
868 }
869 }
870
871 Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
872 Status HandleMerge(Node* n, std::vector<bool>* should_revisit,
873 bool use_optimistic_mode);
874 Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
875 Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
876 Status HandleNode(Node* n, std::vector<bool>* should_revisit,
877 bool use_optimistic_mode = false);
878
879 Status GetFrameBasedTopologicalOrder(std::vector<Node*>* order);
880
IsRootEnter(const Node * n) const881 bool IsRootEnter(const Node* n) const {
882 return IsEnter(n) && control_flow_info_[n->id()].parent_frame->IsSource();
883 }
884
IsRootExit(const Node * n) const885 bool IsRootExit(const Node* n) const {
886 return IsExit(n) && control_flow_info_[n->id()].parent_frame->IsSource();
887 }
888
889 const Graph& graph_;
890 absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
891 PredicateFactory predicate_factory_;
892 std::vector<ControlFlowInfo> control_flow_info_;
893 bool vlog_;
894 absl::flat_hash_map<absl::string_view, Node*> frame_to_merge_node_;
895 };
896
InputEdgeToTensorId(const Edge * e)897 TensorId InputEdgeToTensorId(const Edge* e) {
898 return TensorId(e->src()->name(), e->src_output());
899 }
900
GetInputPreds(Node * n,DeadnessAnalysisImpl::EdgeKind edge_kind,std::vector<Predicate * > * result)901 Status DeadnessAnalysisImpl::GetInputPreds(
902 Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind,
903 std::vector<Predicate*>* result) {
904 result->clear();
905 for (const Edge* in_edge : n->in_edges()) {
906 bool should_process =
907 edge_kind == EdgeKind::kDataAndControl ||
908 (in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) ||
909 (!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly);
910
911 if (should_process) {
912 auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
913 if (it == predicate_map_.end()) {
914 GraphCycles graph_cycles;
915 TF_RETURN_IF_ERROR(
916 CreateCycleDetectionGraph(&graph_, &graph_cycles).status());
917
918 // If we didn't return with an error above then the graph is probably
919 // fine and we have a bug in deadness analysis.
920 return errors::Internal("Could not find input ", in_edge->DebugString(),
921 " to ", n->name(),
922 " when visiting the graph in post-order. Most "
923 "likely indicates a bug in deadness analysis.");
924 }
925 result->push_back(it->second);
926 }
927 }
928 return Status::OK();
929 }
930
HandleSwitch(Node * n,std::vector<bool> * should_revisit)931 Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
932 std::vector<bool>* should_revisit) {
933 std::vector<Predicate*> input_preds;
934 TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
935 const Edge* pred_edge;
936 TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
937
938 if (n->type_string() != "_SwitchN") { // bool pred branch selector.
939 Predicate* true_switch;
940 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
941 pred_edge->src(), pred_edge->src_output(),
942 /*must_be_true=*/true, &true_switch));
943
944 Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);
945
946 // Output 0 is alive iff all inputs are alive and the condition is false.
947 input_preds.push_back(false_switch);
948 SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
949 should_revisit);
950 input_preds.pop_back();
951
952 // Output 1 is alive iff all inputs are alive and the condition is true.
953 input_preds.push_back(true_switch);
954 SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
955 should_revisit);
956 input_preds.pop_back();
957 } else { // N-way switch case. Exactly one of N branches is alive.
958 Predicate* branch_pred;
959 for (int i = 0; i < n->num_outputs() - 1; i++) {
960 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
961 pred_edge->src(), pred_edge->src_output(),
962 /*must_have_value=*/absl::optional<int32>(i), &branch_pred));
963 input_preds.push_back(branch_pred);
964 SetPredicate(n, i, predicate_factory_.MakeAndPredicate(input_preds),
965 should_revisit);
966 input_preds.pop_back();
967 input_preds.push_back(predicate_factory_.MakeNotPredicate(branch_pred));
968 }
969 // The default (last) branch does not need its own symbol, is simply the
970 // nor of all other branches.
971 SetPredicate(n, n->num_outputs() - 1,
972 predicate_factory_.MakeAndPredicate(input_preds),
973 should_revisit);
974 }
975
976 // Control is alive iff all inputs are alive.
977 SetPredicate(n, Graph::kControlSlot,
978 predicate_factory_.MakeAndPredicate(input_preds),
979 should_revisit);
980
981 return Status::OK();
982 }
983
984 namespace {
CreateMultipleNextIterationInputsError(Node * merge)985 Status CreateMultipleNextIterationInputsError(Node* merge) {
986 std::vector<string> backedges;
987 for (const Edge* backedge : merge->in_edges()) {
988 if (backedge->src()->IsNextIteration()) {
989 backedges.push_back(absl::StrCat(" ", SummarizeNode(*backedge->src())));
990 }
991 }
992 return errors::InvalidArgument(
993 "Multiple NextIteration inputs to merge node ",
994 FormatNodeForError(*merge), ": \n", absl::StrJoin(backedges, "\n"),
995 "\nMerge nodes can have at most one incoming NextIteration edge.");
996 }
997
FindUniqueBackedge(Node * merge,const Edge ** result)998 Status FindUniqueBackedge(Node* merge, const Edge** result) {
999 *result = nullptr;
1000 CHECK(merge->IsMerge());
1001 for (const Edge* e : merge->in_edges()) {
1002 if (e->src()->IsNextIteration()) {
1003 if (*result != nullptr) {
1004 return CreateMultipleNextIterationInputsError(merge);
1005 }
1006 *result = e;
1007 }
1008 }
1009 return Status::OK();
1010 }
1011
1012 // If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step
1013 // does not contain `symbolic_predicate` as an inner (not top-level) operand
1014 // then returns `Step`. Otherwise returns nullptr.
DeduceStepPredicate(PredicateFactory * predicate_factory,Predicate * symbolic_predicate,Predicate * backedge_predicate)1015 Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory,
1016 Predicate* symbolic_predicate,
1017 Predicate* backedge_predicate) {
1018 CHECK(dynamic_cast<SymbolPredicate*>(symbolic_predicate));
1019 if (backedge_predicate->kind() != Predicate::Kind::kAnd) {
1020 return nullptr;
1021 }
1022
1023 std::vector<Predicate*> and_ops;
1024 absl::Span<Predicate* const> recurrent_pred_ops =
1025 backedge_predicate->GetOperands();
1026
1027 bool found_sym = false;
1028 for (Predicate* and_op : recurrent_pred_ops) {
1029 // We want the `symbol_predicate` to be the one of the operands of
1030 // `backedge_predicate`,
1031 if (and_op == symbolic_predicate) {
1032 found_sym = true;
1033 continue;
1034 }
1035
1036 // but we don't want it to be present anywhere else in the formula. E.g. we
1037 // don't want the recurrent predicate to be
1038 // symbol_predicate&(X|symbol_predicate).
1039 bool found_sym_as_inner_operand = false;
1040 auto has_self_as_inner_operand = [&](Predicate* p) {
1041 if (p == symbolic_predicate) {
1042 found_sym_as_inner_operand = true;
1043 return true; // Stop searching, we're done.
1044 }
1045
1046 // Continue searching.
1047 return false;
1048 };
1049
1050 Predicate::Visit(and_op, has_self_as_inner_operand);
1051 if (found_sym_as_inner_operand) {
1052 return nullptr;
1053 }
1054 and_ops.push_back(and_op);
1055 }
1056
1057 return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr;
1058 }
1059
GetFullFrame(const Node * n,absl::Span<const ControlFlowInfo> cfi_infos,std::vector<string> * frame)1060 Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
1061 std::vector<string>* frame) {
1062 int depth = 0;
1063 for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource();
1064 n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) {
1065 frame->push_back(cfi_iter->frame_name);
1066
1067 if (depth++ > 5000) {
1068 return errors::Internal(
1069 "Frame of depth > 5000: Probably malformed graph or a bug in "
1070 "BuildControlFlowInfo");
1071 }
1072 }
1073
1074 return Status::OK();
1075 }
1076
1077 // If the node is inside some frames, get the name of the outermost non-empty
1078 // frame. Otherwise, get an empty frame name.
GetRootFrame(const Node * n,absl::Span<const ControlFlowInfo> cfi_infos,absl::string_view * frame)1079 Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
1080 absl::string_view* frame) {
1081 int depth = 0;
1082 const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()];
1083 while (!cfi_iter->parent_frame->IsSource()) {
1084 n = cfi_iter->parent_frame;
1085 cfi_iter = &cfi_infos[n->id()];
1086
1087 if (depth++ > 5000) {
1088 return errors::Internal(
1089 "Frame of depth > 5000: Probably malformed graph or a bug in "
1090 "BuildControlFlowInfo");
1091 }
1092 }
1093
1094 *frame = cfi_iter->frame_name;
1095 return Status::OK();
1096 }
1097 } // namespace
1098
HandleMerge(Node * n,std::vector<bool> * should_revisit,bool use_optimistic_mode)1099 Status DeadnessAnalysisImpl::HandleMerge(Node* n,
1100 std::vector<bool>* should_revisit,
1101 bool use_optimistic_mode) {
1102 // Merge ignores deadness of its control inputs. A merge that isn't the
1103 // target of a backedge has is alive iff any of its data inputs are. The
1104 // liveness of a merge that is the target of a backedge can sometimes be
1105 // represented using a AndRecurrencePredicate. If neither apply, we represent
1106 // the liveness of the merge symbolically.
1107
1108 bool has_unvisited_backedge = false;
1109 for (const Edge* e : n->in_edges()) {
1110 if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
1111 has_unvisited_backedge |= !predicate_map_.count(InputEdgeToTensorId(e));
1112 }
1113 }
1114
1115 auto it = predicate_map_.find(TensorId(n->name(), 0));
1116 if (it == predicate_map_.end()) {
1117 if (has_unvisited_backedge) {
1118 // We're visiting this merge for the first time and it has an unvisited
1119 // backedge.
1120 Predicate* input_data_pred;
1121 if (use_optimistic_mode) {
1122 // In the optimistic mode, we use the first-seen Merge node per
1123 // frame as the representative Merge node. It is just convenient and
1124 // does not affect the result after pattern-matching into the
1125 // AndRecurrence form.
1126 absl::string_view frame_name = control_flow_info_[n->id()].frame_name;
1127 auto insert_result = frame_to_merge_node_.insert({frame_name, n});
1128 Node* representative = insert_result.first->second;
1129 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
1130 representative, /*output_idx=*/0, /*must_be_true=*/false,
1131 &input_data_pred));
1132 } else {
1133 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
1134 n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
1135 }
1136
1137 SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
1138 should_revisit);
1139 return Status::OK();
1140 }
1141
1142 std::vector<Predicate*> input_preds;
1143 TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
1144
1145 // We're visiting this merge for the first time and it is an acyclic merge.
1146 Predicate* input_data_pred =
1147 predicate_factory_.MakeOrPredicate(input_preds);
1148 SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
1149 should_revisit);
1150 return Status::OK();
1151 }
1152
1153 if (it->second->kind() == Predicate::Kind::kSymbol) {
1154 // Last time we visited this merge we only got a symbolic predicate because
1155 // of an unvisited backedge. Try to pattern match the predicate expression
1156 // for that backedge (which should be visited now) into an and recurrence
1157 // for the merge node.
1158 const Edge* unique_backedge;
1159 TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge));
1160 if (unique_backedge) {
1161 if (Predicate* step = DeduceStepPredicate(
1162 &predicate_factory_, it->second,
1163 predicate_map_[InputEdgeToTensorId(unique_backedge)])) {
1164 // If the predicate for the backedge is "Sym&X" where "Sym" is the
1165 // predicate for the merge then the merge has predicate {S,&,X} where S
1166 // is the predicate for the merge ignoring the backedge.
1167 std::vector<Predicate*> non_recurrent_inputs;
1168 for (const Edge* e : n->in_edges()) {
1169 if (e != unique_backedge) {
1170 non_recurrent_inputs.push_back(
1171 predicate_map_[InputEdgeToTensorId(e)]);
1172 }
1173 }
1174
1175 Predicate* start =
1176 predicate_factory_.MakeOrPredicate(non_recurrent_inputs);
1177 std::vector<string> frame;
1178 TF_RETURN_IF_ERROR(GetFullFrame(n, control_flow_info_, &frame));
1179 Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate(
1180 start, step, std::move(frame));
1181 SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit);
1182 return Status::OK();
1183 }
1184 }
1185 }
1186 return Status::OK();
1187 }
1188
HandleRecv(Node * n,std::vector<bool> * should_revisit)1189 Status DeadnessAnalysisImpl::HandleRecv(Node* n,
1190 std::vector<bool>* should_revisit) {
1191 // In addition to being alive or dead based on the inputs, a _Recv can also
1192 // acquire a dead signal from a _Send.
1193 std::vector<Predicate*> input_preds;
1194 TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
1195 Predicate* signal_is_alive;
1196 TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
1197 n, /*output_idx=*/0, /*must_be_true=*/false, &signal_is_alive));
1198 input_preds.push_back(signal_is_alive);
1199 SetPredicate(n, {0, Graph::kControlSlot},
1200 predicate_factory_.MakeAndPredicate(input_preds),
1201 should_revisit);
1202 return Status::OK();
1203 }
1204
HandleGeneric(Node * n,std::vector<bool> * should_revisit)1205 Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
1206 std::vector<bool>* should_revisit) {
1207 // Generally nodes are alive iff all their inputs are alive.
1208 std::vector<Predicate*> input_preds;
1209 TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
1210 Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds);
1211 for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
1212 SetPredicate(n, output_idx, pred, should_revisit);
1213 }
1214 SetPredicate(n, Graph::kControlSlot, pred, should_revisit);
1215 return Status::OK();
1216 }
1217
HandleNode(Node * n,std::vector<bool> * should_revisit,bool use_optimistic_mode)1218 Status DeadnessAnalysisImpl::HandleNode(Node* n,
1219 std::vector<bool>* should_revisit,
1220 bool use_optimistic_mode) {
1221 if (n->IsSwitch()) {
1222 TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
1223 } else if (n->IsMerge()) {
1224 TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit, use_optimistic_mode));
1225 } else if (n->IsControlTrigger()) {
1226 SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
1227 nullptr);
1228 } else if (n->IsRecv() || n->IsHostRecv()) {
1229 TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit));
1230 } else if (n->IsNextIteration()) {
1231 TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
1232 } else {
1233 TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
1234 }
1235 return Status::OK();
1236 }
1237
// Compute a special topological order for the Graph, where nodes having the
// same root frame are placed adjacent to each other.  The traversal uses a
// variant of Kahn's algorithm.  num_ready_inputs is used to keep track of how
// many inputs of each node are ready; a node is ready to be scheduled if all
// of its inputs are ready.
// Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details.
//
// The visited nodes are appended to `order`, starting with the graph's source
// node.  Requires control_flow_info_ to be populated.  Returns
// InvalidArgument if some root-level Enter/Exit nodes were staged but never
// released into the traversal.
Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder(
    std::vector<Node*>* order) {
  // Total number of root-level Enter / Exit nodes per frame name.  These
  // totals are compared against the number of staged nodes below.
  absl::flat_hash_map<absl::string_view, size_t> num_enters_for_frame;
  absl::flat_hash_map<absl::string_view, size_t> num_exits_for_frame;
  // Indexed by node id.
  std::vector<size_t> num_ready_inputs(graph_.num_node_ids(), 0);
  Node* src_node = graph_.source_node();
  for (const auto* node : graph_.op_nodes()) {
    const ControlFlowInfo& cf = control_flow_info_[node->id()];
    if (IsRootEnter(node)) {
      // Since we care only about the root-level frame, full frame names are
      // the same as frame names.
      ++num_enters_for_frame[cf.frame_name];
    } else if (IsRootExit(node)) {
      ++num_exits_for_frame[cf.frame_name];
    }
    // Edge NextIteration->Merge is counted before starting the traversal to
    // break the backedges.
    if (IsMerge(node)) {
      for (const Edge* e : node->in_edges()) {
        if (IsNextIteration(e->src())) {
          ++num_ready_inputs[node->id()];
        }
      }
    }
  }

  // A deque is used to ensure that the nodes are first-in-first-out.  This
  // order guarantees that the exits in the ready queue are visited before
  // nodes that will become ready in the future.
  std::deque<Node*> ready;
  ready.push_back(src_node);
  // ready_enters_per_frame and ready_exits serve as a staging area to buffer
  // the ready enters/exits before they are moved to the `ready` queue for
  // controlling the start and end of a processing frame.
  absl::flat_hash_map<absl::string_view, std::vector<Node*>>
      ready_enters_per_frame;
  // Exit nodes shall all be from the same frame, as we process a frame at a
  // time. So, one vector is enough.
  std::vector<Node*> ready_exits;
  while (!ready.empty()) {
    Node* curr_node = ready.front();
    ready.pop_front();

    VLOG(4) << "Visiting " << curr_node->name();
    order->push_back(curr_node);

    for (const Edge* out_edge : curr_node->out_edges()) {
      Node* out = out_edge->dst();
      int out_id = out->id();
      if (IsNextIteration(curr_node) && IsMerge(out)) {
        // Edge NextIteration->Merge has been counted.
        continue;
      }
      ++num_ready_inputs[out->id()];
      if (!out->IsOp()) continue;  // Skip Sink/Source nodes.
      // `out` is not ready until every one of its in-edges has been counted.
      if (num_ready_inputs[out->id()] != out->in_edges().size()) continue;

      absl::string_view frame_name = control_flow_info_[out_id].frame_name;
      if (IsRootEnter(out)) {
        // Stage root-level enters instead of scheduling them right away.
        ready_enters_per_frame[frame_name].push_back(out);
      } else if (IsRootExit(out)) {
        // Stage root-level exits instead of scheduling them right away.
        ready_exits.push_back(out);
      } else {
        ready.push_back(out);
      }
    }

    if (ready.empty()) {
      // Try moving nodes from ready_enters_per_frame and ready_exits to
      // `ready`.
      if (!ready_exits.empty()) {
        // If there are nodes in ready_exits we must process them before
        // processing ready_enters_per_frame to make sure all nodes in the
        // currently processing frame are visited before starting processing
        // other frames.
        absl::string_view frame_name =
            control_flow_info_[ready_exits.front()->id()].frame_name;
        // All exits of the frame must have become ready by now.
        CHECK_EQ(ready_exits.size(), num_exits_for_frame[frame_name]);
        ready.insert(ready.end(), ready_exits.begin(), ready_exits.end());
        ready_exits.clear();
      } else {
        // Otherwise, try moving nodes from ready_enters to `ready`.  Only a
        // frame for which all of its enters are staged is released, and at
        // most one frame is released per attempt.
        for (auto iter = ready_enters_per_frame.begin();
             iter != ready_enters_per_frame.end(); ++iter) {
          absl::string_view frame_name = iter->first;
          const std::vector<Node*>& ready_enters = iter->second;
          if (ready_enters.size() == num_enters_for_frame[frame_name]) {
            ready.insert(ready.end(), ready_enters.begin(), ready_enters.end());
            // The iterator is not used after the erase: we break immediately.
            ready_enters_per_frame.erase(iter);
            break;
          }
        }
      }
    }
  }

  // Anything still staged was never released into the traversal.
  if (!ready_enters_per_frame.empty() || !ready_exits.empty()) {
    return errors::InvalidArgument(
        "Some enters/exits have never been visited in the traversal."
        " Most probably the input graph is malformed.");
  }
  return Status::OK();
}
1347
1348 // We populate the nodes along a special topological order where nodes having
1349 // the same root frame are placed adjacent to each other. This grouping enables
1350 // processing the graph per root frame at a time and guarantees that when a root
1351 // frame is being processed, nodes in the downstream frames have not yet been
1352 // processed. This property is important because we need to process an entire
1353 // frame to know whether the optimistic mode converges or not. In other words,
1354 // nodes in the downstream frames shall not be populated until all of its
1355 // upstream frames are populated. In effect, this order enables processing each
1356 // (nested) tf.while one-by-one, as each (nested) tf.while creates a unique
1357 // (root) frame. Note that we don't separate while loops belonging to the same
1358 // nested while, as there is no clean cut for separating them in the topological
1359 // order.
Populate(bool enable_optimistic)1360 Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) {
1361 std::vector<string> unreachable_nodes;
1362 // Compute the loop structure of the graph.
1363 TF_RETURN_IF_ERROR(
1364 BuildControlFlowInfo(&graph_, &control_flow_info_, &unreachable_nodes));
1365
1366 // Do some opportunistic error checking:
1367 if (!unreachable_nodes.empty()) {
1368 if (unreachable_nodes.size() > 5) {
1369 unreachable_nodes.erase(unreachable_nodes.begin() + 5,
1370 unreachable_nodes.end());
1371 }
1372
1373 return errors::InvalidArgument(
1374 "Found unreachable nodes, most likely source and sink nodes not "
1375 "connected: ",
1376 absl::StrJoin(unreachable_nodes, ", "));
1377 }
1378
1379 std::vector<Node*> topo;
1380 TF_RETURN_IF_ERROR(GetFrameBasedTopologicalOrder(&topo));
1381
1382 size_t frame_start = 0;
1383 while (frame_start < topo.size()) {
1384 // Batching nodes who have the same root frame.
1385 absl::string_view cur_frame_name;
1386 TF_RETURN_IF_ERROR(
1387 GetRootFrame(topo[frame_start], control_flow_info_, &cur_frame_name));
1388 size_t frame_end = frame_start;
1389 for (size_t i = frame_start + 1; i < topo.size(); ++i) {
1390 absl::string_view i_frame_name;
1391 TF_RETURN_IF_ERROR(
1392 GetRootFrame(topo[i], control_flow_info_, &i_frame_name));
1393 if (i_frame_name == cur_frame_name) {
1394 frame_end = i;
1395 } else {
1396 break;
1397 }
1398 }
1399 absl::Span<Node*> sub_topo(topo.data() + frame_start,
1400 /*length=*/frame_end - frame_start + 1);
1401 frame_start = frame_end + 1;
1402
1403 // First, try the optimistic mode.
1404 bool success = false;
1405 if (enable_optimistic && !cur_frame_name.empty()) {
1406 TF_RETURN_IF_ERROR(
1407 PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success));
1408 }
1409 if (!success) {
1410 // The optimistic mode does not converge. Let's fall back to the
1411 // pessimistic mode.
1412 TF_RETURN_IF_ERROR(
1413 PopulateFrame(sub_topo, /*use_optimistic_mode=*/false, nullptr));
1414 }
1415 VLOG(2) << "Done populating frame " << cur_frame_name << " using the "
1416 << (success ? "optimistic" : "pessimistic") << " mode.";
1417 }
1418
1419 return Status::OK();
1420 }
1421
// Computes predicates for the nodes in `topo`, which must be given in
// topological order (producers before consumers).
//
// `use_optimistic_mode` is forwarded to HandleNode() on the first pass.  In
// the optimistic mode `success` must be non-null and receives whether the
// analysis converged; in the pessimistic mode `success` must be null.  When
// the optimistic mode does not converge, every predicate assigned to the
// nodes in `topo` is removed from predicate_map_ again.
Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
                                           bool use_optimistic_mode,
                                           bool* success) {
  CHECK(use_optimistic_mode && success != nullptr ||
        !use_optimistic_mode && success == nullptr);

  // This is an abstract interpretation over the deadness propagation
  // semantics of the graph executor.
  //
  // We iterate over the graph twice, each time in a topological order.  On the
  // first iteration merge nodes with backedges are mapped to symbolic
  // predicates.  On the second iteration we use the predicates assigned to the
  // backedges in the previous iteration to infer a more precise predicate for
  // the backedge merge nodes and all the nodes that transitively use it.
  //
  // We don't track the output indices for should_revisit.  Instead, putting a
  // node in `should_revisit` denotes that the deadness flowing out from any
  // output from said node may have changed.  This is fine; only switches
  // propagate different deadness along different output edges, and since the
  // delta is solely due to the input *values* (and not input deadness), the
  // delta should not change in the second iteration.
  std::vector<bool> should_revisit;
  should_revisit.resize(graph_.num_node_ids());
  // First pass: assign an initial predicate to every node.  No revisit
  // bookkeeping is done by HandleNode() here (should_revisit is null).
  for (Node* n : topo) {
    VLOG(4) << "Visiting " << n->name();
    TF_RETURN_IF_ERROR(
        HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode));
    if (n->IsNextIteration()) {
      // If this is a backedge for a merge node then remember to reprocess the
      // merge the next time we run.
      for (const Edge* e : n->out_edges()) {
        if (e->dst()->IsMerge()) {
          should_revisit[e->dst()->id()] = true;
        }
      }
    }
  }

  // Second pass.  Note that HandleNode() is called with its default
  // use_optimistic_mode (false) here.
  for (Node* n : topo) {
    // The nodes added to should_revisit in the previous loop need to be
    // revisited now.  Reprocessing these initial nodes may add *their*
    // consumers to should_revisit, and these newly added nodes will also be
    // processed by this very same loop.  Since we're traversing the graph in
    // topological order (producers before consumers) and HandleNode(n) can only
    // ever add n's consumers to should_revisit, we won't "miss" an addition to
    // should_revisit.
    if (should_revisit[n->id()]) {
      VLOG(4) << "Revisiting " << n->name();
      TF_RETURN_IF_ERROR(HandleNode(n, &should_revisit));
    }
  }

  // Check if the optimistic analysis converges.  Specifically, check whether
  // all the predicates of the merge nodes in the same frame are the same.  If
  // yes, report success.  If not, report failure and clear the assigned
  // predicates.
  if (use_optimistic_mode) {
    bool is_converged = true;
    // First and-recurrence predicate seen for each frame name.
    absl::flat_hash_map<absl::string_view, Predicate*> frame_to_pred;
    for (Node* n : topo) {
      if (!n->IsMerge()) {
        continue;
      }
      const Edge* e;
      TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &e));
      if (e == nullptr) {
        // Skip acyclic merge nodes.
        continue;
      }
      Node* merge = n;
      // Note that this uses frame names instead of root frame names.  In the
      // case of a nested while loop, each level of while loops can have merges
      // with different predicate instances, while the merge nodes on the same
      // level must have the same predicate instances.
      absl::string_view frame_name = control_flow_info_[merge->id()].frame_name;
      // NOTE(review): the lookup result is dereferenced without checking for
      // end(); this relies on the first pass having assigned a predicate to
      // every merge in `topo`.
      auto it = predicate_map_.find(TensorId(merge->name(), 0));
      Predicate* merge_pred = it->second;
      if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) {
        is_converged = false;
        VLOG(2) << "Running the optimistic mode on frame " << frame_name
                << " does not converge because node " << merge->name()
                << " cannot be mapped into the AndRecurrence form.";
        break;
      }

      auto insert_result = frame_to_pred.insert({frame_name, merge_pred});
      if (!insert_result.second) {
        // If we have already seen this frame name, verify the predicate is the
        // same as the previously seen one's.
        Predicate* curr_andrec = merge_pred;
        Predicate* prev_andrec = insert_result.first->second;
        if (curr_andrec != prev_andrec) {
          is_converged = false;
          VLOG(2) << "Running the optimistic mode on frame " << frame_name
                  << " does not converge. Seeing different Merge predicates: \n"
                  << curr_andrec->ToString() << " and \n"
                  << prev_andrec->ToString();
          break;
        }
      }
    }

    // Clear the assigned predicates if the optimistic mode does not converge.
    if (!is_converged) {
      for (Node* n : topo) {
        for (int oid = 0; oid < n->num_outputs(); ++oid) {
          predicate_map_.erase(TensorId(n->name(), oid));
        }
        predicate_map_.erase(TensorId(n->name(), Graph::kControlSlot));
      }
    }

    // `success` is non-null here by the CHECK at the top of the function.
    if (success != nullptr) {
      *success = is_converged;
    }
  }

  return Status::OK();
}
1541
1542 StatusOr<DeadnessAnalysis::DeadnessPredicate>
GetPredicateFor(Node * n,int oidx) const1543 DeadnessAnalysisImpl::GetPredicateFor(Node* n, int oidx) const {
1544 auto it = predicate_map_.find(TensorId(n->name(), oidx));
1545 TF_RET_CHECK(it != predicate_map_.end())
1546 << "could not find " << TensorId(n->name(), oidx).ToString()
1547 << " in predicate map";
1548 return MakeDeadnessPredicate(it->second);
1549 }
1550
Print() const1551 void DeadnessAnalysisImpl::Print() const {
1552 std::vector<TensorId> tensor_ids;
1553 for (const auto& kv_pair : predicate_map_) {
1554 tensor_ids.push_back(kv_pair.first);
1555 }
1556
1557 std::sort(tensor_ids.begin(), tensor_ids.end());
1558
1559 for (TensorId tensor_id : tensor_ids) {
1560 auto it = predicate_map_.find(tensor_id);
1561 CHECK(it != predicate_map_.end()) << tensor_id.ToString();
1562 VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString();
1563 }
1564 }
1565
1566 } // namespace
1567
~DeadnessAnalysis()1568 DeadnessAnalysis::~DeadnessAnalysis() {}
1569
Run(const Graph & graph,std::unique_ptr<DeadnessAnalysis> * result)1570 /*static*/ Status DeadnessAnalysis::Run(
1571 const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
1572 std::unique_ptr<DeadnessAnalysisImpl> analysis(
1573 new DeadnessAnalysisImpl(&graph));
1574 TF_RETURN_IF_ERROR(analysis->Populate(/*enable_optimistic=*/true));
1575
1576 if (VLOG_IS_ON(2)) {
1577 analysis->Print();
1578 }
1579
1580 *result = std::move(analysis);
1581 return Status::OK();
1582 }
1583
1584 absl::flat_hash_map<TensorId, string, TensorId::Hasher>
PredicateMapAsString() const1585 DeadnessAnalysisImpl::PredicateMapAsString() const {
1586 absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
1587 for (const auto& kv_pair : predicate_map_) {
1588 CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
1589 }
1590 return result;
1591 }
1592
1593 namespace deadness_analysis_internal {
ComputePredicates(const Graph & graph,PredicateMapTy * out_predicate_map,bool enable_optimistic)1594 Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
1595 bool enable_optimistic) {
1596 DeadnessAnalysisImpl impl(&graph);
1597 TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic));
1598 *out_predicate_map = impl.PredicateMapAsString();
1599 return Status::OK();
1600 }
1601
1602 } // namespace deadness_analysis_internal
1603
DebugString(DeadnessPredicate predicate) const1604 string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const {
1605 return static_cast<Predicate*>(predicate.pred_)->ToString();
1606 }
1607
1608 } // namespace tensorflow
1609