/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/deadness_analysis.h"

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/hash/hash.h"

// ALGORITHM OVERVIEW
// ==================
//
// We map every output produced by each node in the TensorFlow graph (including
// control dependence) into an instance of the Predicate class.  Instances of
// Predicate denote logical formulas and mapping a node `n` to a predicate
// `pred` implies that `n` is live whenever `pred` is true.  Then we can deduce
// mismatching liveness in the inputs to a node by comparing the predicates
// those inputs are mapped to.  The core logic of this pass resides in creating
// the map from TensorFlow nodes to predicates.
//
//
// MAPPING NODES TO PREDICATES, MODULO CYCLES
// ------------------------------------------
//
// If we ignore cycles for a moment, computing predicates is fairly
// straightforward.  We traverse the graph in a topological order, mapping each
// node to a predicate based on the predicates its inputs are mapped to.  For
// instance, a Merge(X, Y) node will be mapped to OR(PredicateFor(X),
// PredicateFor(Y)).  Roughly speaking, we abstractly interpret each node on
// the "liveness" domain, where values in the domain represent whether a
// tensor carries a dead signal or not.
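//
// For instance (an illustrative sketch of what HandleSwitch below computes),
// the two outputs of Switch(Data, Pred) are roughly mapped to
//
//   output 0 (false branch): PredicateFor(Data) & ~"Pred is true"
//   output 1 (true branch):  PredicateFor(Data) &  "Pred is true"
//
// where "Pred is true" is a symbolic predicate asserting that the Pred tensor
// is live and evaluates to true.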
//
//
// DEALING WITH CYCLES
// -------------------
//
// We map Merge nodes that are the target of a backedge to AndRecurrence
// instances.  An AndRecurrence with start() = S and step() = X, printed as
// {S,&,X}, *roughly* represents the infinite list of predicates
// [S,S&X,S&X&X,S&X&X&X, ...].  So {S,&,X} can be used to represent the
// predicate for Merge in a graph like:
//
//     Init
//       |
//       v
//     Merge <-----------+
//       |               |
//       v               |
//      Incr             |
//       |               |
//       v               |
//      Switch <- Cond   |
//       |               |
//       v (oidx: 1)     |
//       |               |
//       +---------------+
//
// Where S is the predicate for Init and X is the predicate that asserts that
// Cond is true.  {S,&,X} states that Merge is live on the first "iteration" iff
// S is true, live on the second iteration iff "S&X" is true, live on the third
// iteration iff "S&X&X" is true etc.  There is a subtlety here: S&X&X would
// normally be equivalent to S&X which isn't quite what we want to represent.
// Instead we want {S,&,X} to denote the infinite list [S, S&X,
// S&X&X',S&X&X'&X'', ...] where X, X', X'' are predicates that assert Cond is
// true on iteration 0, 1, 2 respectively.  This is made more precise in the
// comment on the AndRecurrence class.
//
// The general algorithm that deals with cycles does two topological-order
// iterations over the graph.  On the first iteration it assigns a symbolic
// predicate to merge nodes with backedges.  On the second iteration it tries
// to pattern match the predicates for the backedges of these merges and infer
// an AndRecurrence for the merge.  In other words, we do a data flow analysis
// where the data-flow lattice has two elements, Symbolic and NonSymbolic with
// Symbolic > NonSymbolic.  The lattice has height = 2 so two iterations are
// sufficient to converge.
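//
// Concretely, on the first pass a Merge that is the target of a backedge is
// assigned a fresh symbolic predicate, say Sym.  On the second pass, if the
// predicate computed for its backedge can be written as Sym & X with Sym
// occurring only at the top level, the Merge's predicate is replaced with
// {S,&,X}, where S is the disjunction of the predicates of its non-backedge
// inputs (see HandleMerge and DeduceStepPredicate below).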
//
// We first do an optimistic analysis and, if it does not converge, fall back
// to a pessimistic analysis.  On its first iteration the optimistic analysis
// assigns the same symbolic predicate to all the merge nodes whose preceding
// enter nodes have the same frame name.  On the second iteration, if all the
// merge nodes are pattern-matched into the same AndRecurrence predicate
// instance, the optimistic assignment of a shared symbolic predicate was
// correct and we keep the result of the analysis.
//
// Otherwise, if the optimistic analysis fails to converge, we fall back to the
// pessimistic analysis, which assigns a unique symbolic predicate to each
// merge on the first iteration.  We still use symbolic predicates for merges
// for which we can't pattern match on the backedge predicate.  This is
// conservatively correct.

namespace tensorflow {

namespace {

using se::port::StatusOr;

// Represents a logical predicate, used as described in the algorithm overview
// above.
class Predicate {
 public:
  enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol, kIntSymbol };

  virtual string ToString() const = 0;

  // An ID assigned to the Predicate at construction time.  Conceptually like a
  // pointer, except that it is stable across runs.
  int64 id() const { return id_; }

  virtual absl::Span<Predicate* const> GetOperands() const = 0;

  virtual Kind kind() const = 0;
  virtual ~Predicate() {}

  // Invokes func on p and on all of its operands recursively.  Does not invoke
  // `func` on the same Predicate instance twice.  Aborts the search if `func`
  // returns true.
  template <typename FunctionTy>
  static void Visit(Predicate* p, const FunctionTy& func);

 protected:
  explicit Predicate(int64 id) : id_(id) {}

 private:
  const int64 id_;

  TF_DISALLOW_COPY_AND_ASSIGN(Predicate);
};

// Represents a logical conjunction of a set of predicates.
class AndPredicate : public Predicate {
 public:
  explicit AndPredicate(int64 id, std::vector<Predicate*> operands)
      : Predicate(id), operands_(std::move(operands)) {}

  string ToString() const override {
    if (operands().empty()) {
      return "#true";
    }

    std::vector<string> operands_str;
    std::transform(operands().begin(), operands().end(),
                   std::back_inserter(operands_str),
                   [](Predicate* pred) { return pred->ToString(); });

    return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
  }

  Kind kind() const override { return Kind::kAnd; }

  absl::Span<Predicate* const> GetOperands() const override {
    return operands_;
  }
  absl::Span<Predicate* const> operands() const { return operands_; }

 private:
  std::vector<Predicate*> operands_;
};

// Represents a logical disjunction of a set of predicates.
class OrPredicate : public Predicate {
 public:
  explicit OrPredicate(int64 id, std::vector<Predicate*> operands)
      : Predicate(id), operands_(std::move(operands)) {}

  string ToString() const override {
    if (operands().empty()) {
      return "#false";
    }

    std::vector<string> operands_str;
    std::transform(operands().begin(), operands().end(),
                   std::back_inserter(operands_str),
                   [](Predicate* pred) { return pred->ToString(); });

    return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
  }

  Kind kind() const override { return Kind::kOr; }
  absl::Span<Predicate* const> GetOperands() const override {
    return operands_;
  }
  absl::Span<Predicate* const> operands() const { return operands_; }

 private:
  std::vector<Predicate*> operands_;
};

// Represents the logical negation of a predicate.
class NotPredicate : public Predicate {
 public:
  explicit NotPredicate(int64 id, Predicate* operand)
      : Predicate(id), operands_({operand}) {}

  string ToString() const override {
    return absl::StrCat("~", operand()->ToString());
  }

  Kind kind() const override { return Kind::kNot; }
  Predicate* operand() const { return operands_[0]; }
  absl::Span<Predicate* const> GetOperands() const override {
    return operands_;
  }

 private:
  std::array<Predicate*, 1> operands_;
};

// Represents the liveness of an induction variable.  For users inside the loop
// this represents the "current" liveness of the induction variable.  For users
// outside the loop it represents the "last" liveness of the induction variable.
//
// More concretely, an and recurrence {S,&,X}<loop> represents the liveness of V
// in the following graph:
//
//   V = Merge(S', V_NextIt)
//   V = Op(V, X')
//   V_NextIt = NextIteration(V)
//
// where Predicate(S') = S and Predicate(X') = X.
//
// `X` may contain symbolic predicates and the operations corresponding to these
// symbolic predicates are either in frame `loop` or outside it.  The symbols
// that are inside frame `loop` are loop variant (i.e. can have different
// liveness in each loop iteration) and the symbols that are outside frame
// `loop` are loop invariant (i.e. have the same liveness across all
// iterations).
class AndRecurrencePredicate : public Predicate {
 public:
  explicit AndRecurrencePredicate(int64 id, Predicate* start, Predicate* step,
                                  std::vector<string> frame)
      : Predicate(id), operands_({start, step}), frame_(std::move(frame)) {}

  Predicate* start() const { return operands_[0]; }
  Predicate* step() const { return operands_[1]; }
  absl::Span<const string> frame() const { return frame_; }

  string ToString() const override {
    return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
                        "}<", absl::StrJoin(frame(), ";"), ">");
  }

  Kind kind() const override { return Kind::kAndRecurrence; }

  absl::Span<Predicate* const> GetOperands() const override {
    return operands_;
  }

 private:
  std::array<Predicate*, 2> operands_;
  std::vector<string> frame_;
};

// Represents an uninterpreted symbol in a logical predicate.
//
// Two predicates are equivalent iff they are equivalent for all assignments to
// the symbols contained in them, i.e. predicates are forall-quantified over
// symbols.
class SymbolPredicate : public Predicate {
 public:
  explicit SymbolPredicate(int64 id, TensorId tensor_id, bool must_be_true)
      : Predicate(id),
        tensor_id_(std::move(tensor_id)),
        must_be_true_(must_be_true) {}

  string ToString() const override {
    return must_be_true() ? absl::StrCat("*", tensor_id_.ToString())
                          : tensor_id_.ToString();
  }

  Kind kind() const override { return Kind::kSymbol; }
  absl::Span<Predicate* const> GetOperands() const override { return {}; }

  // If `must_be_true()` is true, this SymbolPredicate represents the
  // proposition "tensor_id() is live and evaluates to true".
  //
  // If `must_be_true()` is false, then this SymbolPredicate represents the
  // proposition "tensor_id() is live (and may evaluate to any value)".
  TensorId tensor_id() const { return tensor_id_; }
  bool must_be_true() const { return must_be_true_; }

 private:
  TensorId tensor_id_;
  bool must_be_true_;
};

// Represents an uninterpreted symbol in a logical predicate.
//
// Two predicates are equivalent iff they are equivalent for all assignments to
// the symbols contained in them, i.e. predicates are forall-quantified over
// symbols.
class IntSymbolPredicate : public Predicate {
 public:
  explicit IntSymbolPredicate(int64 id, TensorId tensor_id,
                              absl::optional<int> must_have_value)
      : Predicate(id),
        tensor_id_(std::move(tensor_id)),
        must_have_value_(must_have_value) {}

  string ToString() const override {
    return must_have_value().has_value()
               ? absl::StrCat(tensor_id_.ToString(), "=", *must_have_value_)
               : tensor_id_.ToString();
  }

  Kind kind() const override { return Kind::kIntSymbol; }
  absl::Span<Predicate* const> GetOperands() const override { return {}; }

  // If `must_have_value().has_value()` is true, then this IntSymbolPredicate
  // represents the proposition "tensor_id() is live and evaluates to
  // `*must_have_value()`".
  //
  // If `must_have_value().has_value()` is false, then this IntSymbolPredicate
  // represents the proposition "tensor_id() is live (and may evaluate to any
  // value)".
  TensorId tensor_id() const { return tensor_id_; }
  const absl::optional<int>& must_have_value() const {
    return must_have_value_;
  }

 private:
  TensorId tensor_id_;
  absl::optional<int> must_have_value_;
};

template <typename FunctionTy>
/*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) {
  absl::flat_hash_set<Predicate*> visited;
  std::vector<Predicate*> stack;

  stack.push_back(p);
  visited.insert(p);

  while (!stack.empty()) {
    Predicate* current = stack.back();
    stack.pop_back();
    bool done = func(current);
    if (done) {
      return;
    }
    for (Predicate* op : current->GetOperands()) {
      if (visited.insert(op).second) {
        stack.push_back(op);
      }
    }
  }
}

// Creates and owns Predicate instances.  Simplifies predicates as it creates
// them.
class PredicateFactory {
 public:
  Predicate* MakeAndPredicate(absl::Span<Predicate* const> operands) {
    return MakeAndOrImpl(operands, /*is_and=*/true);
  }

  Predicate* MakeOrPredicate(absl::Span<Predicate* const> operands) {
    return MakeAndOrImpl(operands, /*is_and=*/false);
  }

  Predicate* MakeNotPredicate(Predicate* pred) {
    auto it = make_not_predicate_cache_.find(pred);
    if (it != make_not_predicate_cache_.end()) {
      return it->second;
    }

    Predicate* result = MakeNotPredicateImpl(pred);

    bool insert_successful =
        make_not_predicate_cache_.insert({pred, result}).second;
    (void)insert_successful;
    DCHECK(insert_successful);

    return result;
  }

  Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step,
                                        std::vector<string> frame) {
    SignatureForAndRec signature(start, step, std::move(frame));
    auto it = interned_and_rec_instances_.find(signature);
    if (it != interned_and_rec_instances_.end()) {
      return it->second.get();
    }

    std::unique_ptr<Predicate> new_pred = Make<AndRecurrencePredicate>(
        std::get<0>(signature), std::get<1>(signature), std::get<2>(signature));
    Predicate* new_pred_ptr = new_pred.get();
    bool inserted =
        interned_and_rec_instances_.emplace(signature, std::move(new_pred))
            .second;
    (void)inserted;
    DCHECK(inserted);
    return new_pred_ptr;
  }

  Status MakeSymbolPredicate(Node* node, int output_idx, bool must_be_true,
                             Predicate** predicate) {
    TensorId tensor_id(node->name(), output_idx);

    bool is_boolean_tensor =
        BaseType(node->output_type(tensor_id.index())) == DT_BOOL;
    TF_RET_CHECK(!must_be_true || is_boolean_tensor);

    if (node->type_string() == "Const" && must_be_true) {
      const TensorProto* proto = nullptr;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto));

      Tensor tensor(proto->dtype());
      TF_RET_CHECK(tensor.FromProto(*proto));

      *predicate = tensor.scalar<bool>()() ? MakeTrue() : MakeFalse();
      return Status::OK();
    }

    SignatureForSymbol signature = {tensor_id, must_be_true};
    auto it = interned_symbol_instances_.find(signature);
    if (it == interned_symbol_instances_.end()) {
      std::unique_ptr<Predicate> new_pred =
          Make<SymbolPredicate>(tensor_id, must_be_true);
      Predicate* new_pred_ptr = new_pred.get();
      interned_symbol_instances_.emplace(std::move(signature),
                                         std::move(new_pred));
      *predicate = new_pred_ptr;
    } else {
      *predicate = it->second.get();
    }

    return Status::OK();
  }

  Status MakeSymbolPredicate(Node* node, int output_idx,
                             absl::optional<int> must_have_value,
                             Predicate** predicate) {
    TensorId tensor_id(node->name(), output_idx);

    TF_RET_CHECK(BaseType(node->output_type(tensor_id.index())) == DT_INT32);

    if (must_have_value.has_value() && node->type_string() == "Const") {
      const TensorProto* proto = nullptr;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto));

      Tensor tensor(proto->dtype());
      TF_RET_CHECK(tensor.FromProto(*proto));

      *predicate = tensor.scalar<int32>()() == *must_have_value ? MakeTrue()
                                                                : MakeFalse();
      return Status::OK();
    }
    SignatureForIntSymbol signature = {tensor_id, must_have_value};
    auto it = interned_int_symbol_instances_.find(signature);
    if (it == interned_int_symbol_instances_.end()) {
      std::unique_ptr<Predicate> new_pred =
          Make<IntSymbolPredicate>(tensor_id, must_have_value);
      Predicate* new_pred_ptr = new_pred.get();
      interned_int_symbol_instances_.emplace(std::move(signature),
                                             std::move(new_pred));
      *predicate = new_pred_ptr;
    } else {
      *predicate = it->second.get();
    }

    return Status::OK();
  }

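  // By convention an empty conjunction is true and an empty disjunction is
  // false, so these reuse MakeAndPredicate/MakeOrPredicate with no operands
  // (printed as "#true" and "#false" respectively).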
  Predicate* MakeTrue() { return MakeAndPredicate({}); }
  Predicate* MakeFalse() { return MakeOrPredicate({}); }

  ~PredicateFactory() {
    DCHECK_EQ(stack_depth_, 0) << "Unnested IncrementStackDepth?";
  }

 private:
  Predicate* MakeNotPredicateImpl(Predicate* pred) {
    IncrementStackDepth stack_frame(this);
    if (!stack_frame.HasOverflowed()) {
      if (Predicate* simplified = SimplifyUsingDeMorgan(pred)) {
        return simplified;
      }

      // ~~A => A
      if (auto* not_pred = dynamic_cast<NotPredicate*>(pred)) {
        return not_pred->operand();
      }
    }

    SignatureForNot signature = pred;
    auto it = interned_not_instances_.find(signature);
    if (it == interned_not_instances_.end()) {
      std::unique_ptr<Predicate> new_pred = Make<NotPredicate>(pred);
      Predicate* new_pred_ptr = new_pred.get();
      interned_not_instances_.emplace(signature, std::move(new_pred));
      return new_pred_ptr;
    } else {
      return it->second.get();
    }
  }

  Predicate* SimplifyUsingDeMorgan(Predicate* pred) {
    // ~(A & B & C & ...) => ~A | ~B | ~C | ~...
    // ~(A | B | C | ...) => ~A & ~B & ~C & ~...
    Predicate::Kind kind = pred->kind();

    if (kind == Predicate::Kind::kAnd || kind == Predicate::Kind::kOr) {
      std::vector<Predicate*> new_operands;
      absl::c_transform(pred->GetOperands(), std::back_inserter(new_operands),
                        [&](Predicate* p) { return MakeNotPredicate(p); });
      return kind == Predicate::Kind::kOr ? MakeAndPredicate(new_operands)
                                          : MakeOrPredicate(new_operands);
    }

    return nullptr;
  }

  template <typename PredicateT, typename... Args>
  std::unique_ptr<Predicate> Make(Args&&... args) {
    // If we ever expose the Predicate class outside this .cc file then we may
    // want to make this hard to misuse (by accidentally passing in an arbitrary
    // integer to the Predicate constructor for instance).
    return std::unique_ptr<PredicateT>(
        new PredicateT(id_counter_++, std::forward<Args>(args)...));
  }

  Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
  Predicate* MakeInternedAndOr(std::vector<Predicate*> simplified_ops,
                               Predicate::Kind pred_kind);

  // Predicate instances are interned, meaning that there is only a single
  // instance of a Predicate object with a given content.  This makes checking
  // for structural equality super-cheap -- we can just compare pointers.
  //
  // We intern predicates by maintaining a map from the content of a Predicate
  // to the only instance of said predicate we allow to exist in the
  // interned_and_or_instances_, interned_not_instances_ and
  // interned_symbol_instances_ fields.  These maps also double up as storage
  // for the owning pointers to predicate instances.

  using SignatureForAndOr =
      std::pair<Predicate::Kind, absl::Span<Predicate* const>>;
  using SignatureForNot = Predicate*;
  using SignatureForAndRec =
      std::tuple<Predicate*, Predicate*, std::vector<string>>;
  using SignatureForSymbol = std::pair<SafeTensorId, bool>;
  using SignatureForIntSymbol = std::pair<SafeTensorId, absl::optional<int32>>;

  struct HashSignatureForAndOr {
    size_t operator()(const SignatureForAndOr& signature) const {
      size_t hash = ::tensorflow::hash<Predicate::Kind>()(signature.first);
      for (Predicate* p : signature.second) {
        hash = Hash64Combine(hash, ::tensorflow::hash<Predicate*>()(p));
      }
      return hash;
    }
  };

  struct HashSignatureForSymbol {
    size_t operator()(const SignatureForSymbol& signature) const {
      return Hash64Combine(SafeTensorId::Hasher()(signature.first),
                           ::tensorflow::hash<bool>()(signature.second));
    }
  };

  struct HashSignatureForIntSymbol {
    size_t operator()(const SignatureForIntSymbol& signature) const {
      return Hash64Combine(
          SafeTensorId::Hasher()(signature.first),
          Hash64Combine(
              ::tensorflow::hash<bool>()(signature.second.has_value()),
              ::tensorflow::hash<int32>()(
                  signature.second.has_value() ? *signature.second : 0)));
    }
  };

  // Used to limit recursion to avoid blowing up the stack and cap compile time.
  class IncrementStackDepth {
   public:
    explicit IncrementStackDepth(PredicateFactory* parent) : parent_(parent) {
      parent_->stack_depth_++;
    }

    bool HasOverflowed() const {
      const int kMaxStackDepth = 8;
      return parent_->stack_depth_ >= kMaxStackDepth;
    }

    ~IncrementStackDepth() { parent_->stack_depth_--; }

   private:
    PredicateFactory* parent_;
  };

  // A cache for the MakeNotPredicate function.
  //
  // NB! This is *not* the same as `interned_not_instances_`.
  // `interned_not_instances_` ensures pointer identity for `NotPredicate`
  // instances, i.e., it ensures there is at most one instance of Not(predicate)
  // for any given predicate, whereas `make_not_predicate_cache_` simply caches
  // the result of the `MakeNotPredicate` function.  The values in
  // `interned_not_instances_` are always instances of `NotPredicate` whereas
  // the values in `make_not_predicate_cache_` may not be (for instance it will
  // map Not(Not(A)) to A).
  absl::flat_hash_map<Predicate*, Predicate*> make_not_predicate_cache_;

  absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>,
                      HashSignatureForAndOr>
      interned_and_or_instances_;
  absl::flat_hash_map<SignatureForNot, std::unique_ptr<Predicate>>
      interned_not_instances_;
  absl::flat_hash_map<SignatureForAndRec, std::unique_ptr<Predicate>>
      interned_and_rec_instances_;
  absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>,
                      HashSignatureForSymbol>
      interned_symbol_instances_;
  absl::flat_hash_map<SignatureForIntSymbol, std::unique_ptr<Predicate>,
                      HashSignatureForIntSymbol>
      interned_int_symbol_instances_;
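  // Monotonically increasing counter used to assign Predicate::id() values at
  // construction time (see Make<PredicateT> above).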
  int64 id_counter_ = 0;
  int stack_depth_ = 0;
};

Predicate* PredicateFactory::MakeInternedAndOr(
    std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) {
  std::stable_sort(
      simplified_ops.begin(), simplified_ops.end(),
      [](Predicate* a, Predicate* b) { return a->id() < b->id(); });

  auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
  if (it != interned_and_or_instances_.end()) {
    return it->second.get();
  }

  simplified_ops.shrink_to_fit();
  // NB!  Because we'll use a non-owning reference to simplified_ops in the
  // key for interned_and_or_instances_ we need to be careful to std::move()
  // it all the way through.
  absl::Span<Predicate* const> operands_slice = simplified_ops;
  std::unique_ptr<Predicate> new_pred =
      pred_kind == Predicate::Kind::kAnd
          ? Make<AndPredicate>(std::move(simplified_ops))
          : Make<OrPredicate>(std::move(simplified_ops));

  Predicate* new_pred_ptr = new_pred.get();
  interned_and_or_instances_.emplace(
      SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred));
  return new_pred_ptr;
}

// Common code to create AndPredicate or OrPredicate instances.
Predicate* PredicateFactory::MakeAndOrImpl(
    absl::Span<Predicate* const> operands, bool is_and) {
  Predicate::Kind pred_kind =
      is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;

  IncrementStackDepth stack_frame(this);
  if (stack_frame.HasOverflowed()) {
    return MakeInternedAndOr(
        std::vector<Predicate*>(operands.begin(), operands.end()), pred_kind);
  }

  Predicate::Kind other_pred_kind =
      is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
  absl::flat_hash_set<Predicate*> simplified_ops_set;
  std::vector<Predicate*> simplified_ops;
  for (Predicate* op : operands) {
    // Simplify A&A => A and A|A => A.
    if (!simplified_ops_set.insert(op).second) {
      continue;
    }

    if (op->kind() == pred_kind) {
      // "Inline" the operands of an inner And/Or into the parent And/Or.
      for (Predicate* subop : op->GetOperands()) {
        if (simplified_ops_set.insert(subop).second) {
          simplified_ops.push_back(subop);
        }
      }
    } else {
      simplified_ops.push_back(op);
    }
  }

  if (simplified_ops.size() == 1) {
    return simplified_ops[0];
  }

  // Simplify "A&~A=>False" and "A|~A=>True".
  absl::flat_hash_set<Predicate*> negated_ops;
  for (Predicate* op : simplified_ops) {
    if (negated_ops.count(op)) {
      // Simple case:
      //
      //   A & ~A & ... == False
      //   A | ~A | ... == True
      return is_and ? MakeFalse() : MakeTrue();
    }

    Predicate* negated_op = MakeNotPredicate(op);
    if (negated_op->kind() == pred_kind) {
      // Slightly more complicated case:
      //
      //   (~A | ~B | ~C) & A & B & C & ... ==
      //   ~(A & B & C) & (A & B & C) & ... == False
      //
      //   (~A & ~B & ~C) | A | B | C | ... ==
      //   ~(A | B | C) | (A | B | C) | ... == True
      if (absl::c_all_of(negated_op->GetOperands(), [&](Predicate* p) {
            return simplified_ops_set.contains(p);
          })) {
        return is_and ? MakeFalse() : MakeTrue();
      }
    }
    negated_ops.insert(negated_op);
  }

  // Simplify {S,&,X} & ~X & ... => S & ...
  if (is_and) {
    absl::flat_hash_set<Predicate*> to_remove;
    std::vector<Predicate*> to_add;
    for (Predicate* op : simplified_ops) {
      if (op->kind() == Predicate::Kind::kAndRecurrence) {
        auto* and_rec = static_cast<AndRecurrencePredicate*>(op);
        if (negated_ops.contains(and_rec->step())) {
          // Remove and_rec and ~X and insert S.  Note that checking the
          // existence of ~X through negated_ops is sufficient since it makes
          // sure the predicate is in the input operands.  It does not need to
          // be in simplified_ops if it was already cancelled out.
          to_remove.insert(and_rec);
          to_remove.insert(MakeNotPredicate(and_rec->step()));
          to_add.push_back(and_rec->start());
        }
      }
    }
    auto it = simplified_ops.begin();
    while (it != simplified_ops.end()) {
      if (to_remove.contains(*it)) {
        it = simplified_ops.erase(it);
      } else {
        ++it;
      }
    }
    simplified_ops.insert(simplified_ops.end(), to_add.begin(), to_add.end());
  }

  // If all ops contain the same subop, then factor it out thanks to the
  // distributive property. Such as:
  // - (A & B) | (A & C) | (A & D) => A & (B | C | D)
  // - (A | B) & (A | C) & (A | D) => A | (B & C & D)
  //
  // First find any predicates contained in all subops.
  std::vector<Predicate*> common_inner_operands;
  absl::flat_hash_set<Predicate*> common_inner_operands_set;
  for (Predicate* op : simplified_ops) {
    if (op->kind() != other_pred_kind) {
      common_inner_operands.clear();
      break;
    }

    if (common_inner_operands.empty()) {
      common_inner_operands.insert(common_inner_operands.end(),
                                   op->GetOperands().begin(),
                                   op->GetOperands().end());
    } else {
      common_inner_operands.clear();
      absl::c_copy_if(op->GetOperands(),
                      std::back_inserter(common_inner_operands),
                      [&](Predicate* sub_op) {
                        return common_inner_operands_set.count(sub_op) == 1;
                      });
    }
    if (common_inner_operands.empty()) break;
    common_inner_operands_set.clear();
    common_inner_operands_set.insert(common_inner_operands.begin(),
                                     common_inner_operands.end());
  }

  if (common_inner_operands.empty()) {
    return MakeInternedAndOr(std::move(simplified_ops), pred_kind);
  }

  // For all predicates that can be factored out, remove them and recreate the
  // subops.
  std::vector<Predicate*> factored_ops;
  for (Predicate* op : simplified_ops) {
    std::vector<Predicate*> new_sub_op_ops;
    absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops),
                    [&](Predicate* sub_op) {
                      return std::find(common_inner_operands.begin(),
                                       common_inner_operands.end(),
                                       sub_op) == common_inner_operands.end();
                    });
    factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and));
  }

  Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and);
  std::vector<Predicate*> outer_ops;
  outer_ops.push_back(new_inner_op);
  outer_ops.insert(outer_ops.end(), common_inner_operands.begin(),
                   common_inner_operands.end());
  return MakeAndOrImpl(outer_ops, !is_and);
}

class DeadnessAnalysisImpl : public DeadnessAnalysis {
 public:
  explicit DeadnessAnalysisImpl(const Graph* graph)
      : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}

  Status Populate(bool enable_optimistic);
  Status PopulateFrame(absl::Span<Node* const> topo, bool use_optimistic_mode,
                       bool* success);
  StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
      Node* n, int oidx) const override;
  void Print() const override;
  absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
      const;

 private:
  enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };

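  // Populates `*result` with the predicates already computed for `n`'s
  // in-edges of the given `edge_kind`.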
  Status GetInputPreds(Node* n, EdgeKind edge_kind,
                       std::vector<Predicate*>* result);

  // Sets the predicate for output `output_idx` of `n` to `pred`.  If `pred` is
  // different from the current predicate for that output, marks all consumers
  // of `n` for revisiting by setting their bits in `should_revisit`.
  void SetPredicate(Node* n, int output_idx, Predicate* pred,
                    std::vector<bool>* should_revisit) {
    auto insert_result =
        predicate_map_.insert({TensorId(n->name(), output_idx), pred});
    if (!insert_result.second && insert_result.first->second != pred) {
      VLOG(4) << "For " << n->name() << ":" << output_idx << " from "
              << insert_result.first->second->ToString() << " "
              << insert_result.first->second << " to " << pred->ToString()
              << " " << pred;
      insert_result.first->second = pred;
      if (should_revisit != nullptr) {
        for (const Edge* e : n->out_edges()) {
          (*should_revisit)[e->dst()->id()] = true;
        }
      }
    }
  }

  void SetPredicate(Node* n, absl::Span<const int> output_idxs, Predicate* pred,
                    std::vector<bool>* should_revisit) {
    for (int output_idx : output_idxs) {
      SetPredicate(n, output_idx, pred, should_revisit);
    }
  }

  Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
  Status HandleMerge(Node* n, std::vector<bool>* should_revisit,
                     bool use_optimistic_mode);
  Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
  Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
  Status HandleNode(Node* n, std::vector<bool>* should_revisit,
                    bool use_optimistic_mode = false);

  Status GetFrameBasedTopologicalOrder(std::vector<Node*>* order);

  bool IsRootEnter(const Node* n) const {
    return IsEnter(n) && control_flow_info_[n->id()].parent_frame->IsSource();
  }

  bool IsRootExit(const Node* n) const {
    return IsExit(n) && control_flow_info_[n->id()].parent_frame->IsSource();
  }

  const Graph& graph_;
  absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
  PredicateFactory predicate_factory_;
  std::vector<ControlFlowInfo> control_flow_info_;
  bool vlog_;
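  // Maps a frame name to the representative Merge node chosen for that frame
  // when running in the optimistic mode (see HandleMerge).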
  absl::flat_hash_map<absl::string_view, Node*> frame_to_merge_node_;
};

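// Returns the TensorId for the source output that edge `e` carries.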
TensorId InputEdgeToTensorId(const Edge* e) {
  return TensorId(e->src()->name(), e->src_output());
}

Status DeadnessAnalysisImpl::GetInputPreds(
    Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind,
    std::vector<Predicate*>* result) {
  result->clear();
  for (const Edge* in_edge : n->in_edges()) {
    bool should_process =
        edge_kind == EdgeKind::kDataAndControl ||
        (in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) ||
        (!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly);

    if (should_process) {
      auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
      if (it == predicate_map_.end()) {
        GraphCycles graph_cycles;
        TF_RETURN_IF_ERROR(
            CreateCycleDetectionGraph(&graph_, &graph_cycles).status());

        // If we didn't return with an error above then the graph is probably
        // fine and we have a bug in deadness analysis.
        return errors::Internal("Could not find input ", in_edge->DebugString(),
                                " to ", n->name(),
                                " when visiting the graph in post-order.  Most "
                                "likely indicates a bug in deadness analysis.");
      }
      result->push_back(it->second);
    }
  }
  return Status::OK();
}

Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
                                          std::vector<bool>* should_revisit) {
  std::vector<Predicate*> input_preds;
  TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
  const Edge* pred_edge;
  TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));

  if (n->type_string() != "_SwitchN") {  // bool pred branch selector.
    Predicate* true_switch;
    TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
        pred_edge->src(), pred_edge->src_output(),
        /*must_be_true=*/true, &true_switch));

    Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);

    // Output 0 is alive iff all inputs are alive and the condition is false.
    input_preds.push_back(false_switch);
    SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
                 should_revisit);
    input_preds.pop_back();

    // Output 1 is alive iff all inputs are alive and the condition is true.
    input_preds.push_back(true_switch);
    SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
                 should_revisit);
    input_preds.pop_back();
  } else {  // N-way switch case. Exactly one of N branches is alive.
    Predicate* branch_pred;
    for (int i = 0; i < n->num_outputs() - 1; i++) {
      TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
          pred_edge->src(), pred_edge->src_output(),
          /*must_have_value=*/absl::optional<int32>(i), &branch_pred));
      input_preds.push_back(branch_pred);
      SetPredicate(n, i, predicate_factory_.MakeAndPredicate(input_preds),
                   should_revisit);
      input_preds.pop_back();
      input_preds.push_back(predicate_factory_.MakeNotPredicate(branch_pred));
    }
    // The default (last) branch does not need its own symbol; it is simply the
    // NOR of all the other branches.
    SetPredicate(n, n->num_outputs() - 1,
                 predicate_factory_.MakeAndPredicate(input_preds),
                 should_revisit);
  }

  // Control is alive iff all inputs are alive.
  SetPredicate(n, Graph::kControlSlot,
               predicate_factory_.MakeAndPredicate(input_preds),
               should_revisit);

  return Status::OK();
}

namespace {
Status CreateMultipleNextIterationInputsError(Node* merge) {
  std::vector<string> backedges;
  for (const Edge* backedge : merge->in_edges()) {
    if (backedge->src()->IsNextIteration()) {
      backedges.push_back(absl::StrCat("  ", SummarizeNode(*backedge->src())));
    }
  }
  return errors::InvalidArgument(
      "Multiple NextIteration inputs to merge node ",
      FormatNodeForError(*merge), ": \n", absl::StrJoin(backedges, "\n"),
      "\nMerge nodes can have at most one incoming NextIteration edge.");
}

Status FindUniqueBackedge(Node* merge, const Edge** result) {
  *result = nullptr;
  CHECK(merge->IsMerge());
  for (const Edge* e : merge->in_edges()) {
    if (e->src()->IsNextIteration()) {
      if (*result != nullptr) {
        return CreateMultipleNextIterationInputsError(merge);
      }
      *result = e;
    }
  }
  return Status::OK();
}

// If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step
// does not contain `symbolic_predicate` as an inner (not top-level) operand
// then returns `Step`.  Otherwise returns nullptr.
Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory,
                               Predicate* symbolic_predicate,
                               Predicate* backedge_predicate) {
  CHECK(dynamic_cast<SymbolPredicate*>(symbolic_predicate));
  if (backedge_predicate->kind() != Predicate::Kind::kAnd) {
    return nullptr;
  }

  std::vector<Predicate*> and_ops;
  absl::Span<Predicate* const> recurrent_pred_ops =
      backedge_predicate->GetOperands();

  bool found_sym = false;
  for (Predicate* and_op : recurrent_pred_ops) {
    // We want `symbolic_predicate` to be one of the operands of
    // `backedge_predicate`,
    if (and_op == symbolic_predicate) {
      found_sym = true;
      continue;
    }

    // but we don't want it to be present anywhere else in the formula.  E.g. we
    // don't want the recurrent predicate to be
    // symbolic_predicate&(X|symbolic_predicate).
    bool found_sym_as_inner_operand = false;
    auto has_self_as_inner_operand = [&](Predicate* p) {
      if (p == symbolic_predicate) {
        found_sym_as_inner_operand = true;
        return true;  // Stop searching, we're done.
      }

      // Continue searching.
      return false;
    };

    Predicate::Visit(and_op, has_self_as_inner_operand);
    if (found_sym_as_inner_operand) {
      return nullptr;
    }
    and_ops.push_back(and_op);
  }

  return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr;
}

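// Appends to `*frame` the stack of frame names enclosing `n`, innermost frame
// first, walking out towards the root frame.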
Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
                    std::vector<string>* frame) {
  int depth = 0;
  for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource();
       n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) {
    frame->push_back(cfi_iter->frame_name);

    if (depth++ > 5000) {
      return errors::Internal(
          "Frame of depth > 5000:  Probably malformed graph or a bug in "
          "BuildControlFlowInfo");
    }
  }

  return Status::OK();
}

// If the node is inside some frames, get the name of the outermost non-empty
// frame.  Otherwise, get an empty frame name.
Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
                    absl::string_view* frame) {
  int depth = 0;
  const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()];
  while (!cfi_iter->parent_frame->IsSource()) {
    n = cfi_iter->parent_frame;
    cfi_iter = &cfi_infos[n->id()];

    if (depth++ > 5000) {
      return errors::Internal(
          "Frame of depth > 5000:  Probably malformed graph or a bug in "
          "BuildControlFlowInfo");
    }
  }

  *frame = cfi_iter->frame_name;
  return Status::OK();
}
}  // namespace

Status DeadnessAnalysisImpl::HandleMerge(Node* n,
                                         std::vector<bool>* should_revisit,
                                         bool use_optimistic_mode) {
  // Merge ignores deadness of its control inputs.  A merge that isn't the
  // target of a backedge is alive iff any of its data inputs are.  The
  // liveness of a merge that is the target of a backedge can sometimes be
  // represented using an AndRecurrencePredicate.  If neither applies, we
  // represent the liveness of the merge symbolically.

  bool has_unvisited_backedge = false;
  for (const Edge* e : n->in_edges()) {
    if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
      has_unvisited_backedge |= !predicate_map_.count(InputEdgeToTensorId(e));
    }
  }

  auto it = predicate_map_.find(TensorId(n->name(), 0));
  if (it == predicate_map_.end()) {
    if (has_unvisited_backedge) {
      // We're visiting this merge for the first time and it has an unvisited
      // backedge.
      Predicate* input_data_pred;
      if (use_optimistic_mode) {
        // In the optimistic mode, we use the first-seen Merge node per
        // frame as the representative Merge node.  It is just convenient and
        // does not affect the result after pattern-matching into the
        // AndRecurrence form.
        absl::string_view frame_name = control_flow_info_[n->id()].frame_name;
        auto insert_result = frame_to_merge_node_.insert({frame_name, n});
        Node* representative = insert_result.first->second;
        TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
            representative, /*output_idx=*/0, /*must_be_true=*/false,
            &input_data_pred));
      } else {
        TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
            n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
      }

      SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
                   should_revisit);
      return Status::OK();
    }

    std::vector<Predicate*> input_preds;
    TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));

    // We're visiting this merge for the first time and it is an acyclic merge.
    Predicate* input_data_pred =
        predicate_factory_.MakeOrPredicate(input_preds);
    SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
                 should_revisit);
    return Status::OK();
  }

  if (it->second->kind() == Predicate::Kind::kSymbol) {
    // Last time we visited this merge we only got a symbolic predicate because
    // of an unvisited backedge.  Try to pattern match the predicate expression
    // for that backedge (which should be visited now) into an and recurrence
    // for the merge node.
    const Edge* unique_backedge;
    TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge));
    if (unique_backedge) {
      if (Predicate* step = DeduceStepPredicate(
              &predicate_factory_, it->second,
              predicate_map_[InputEdgeToTensorId(unique_backedge)])) {
        // If the predicate for the backedge is "Sym&X" where "Sym" is the
        // predicate for the merge then the merge has predicate {S,&,X} where S
        // is the predicate for the merge ignoring the backedge.
        std::vector<Predicate*> non_recurrent_inputs;
        for (const Edge* e : n->in_edges()) {
          if (e != unique_backedge) {
            non_recurrent_inputs.push_back(
                predicate_map_[InputEdgeToTensorId(e)]);
          }
        }

        Predicate* start =
            predicate_factory_.MakeOrPredicate(non_recurrent_inputs);
        std::vector<string> frame;
        TF_RETURN_IF_ERROR(GetFullFrame(n, control_flow_info_, &frame));
        Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate(
            start, step, std::move(frame));
        SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit);
        return Status::OK();
      }
    }
  }
  return Status::OK();
}

Status DeadnessAnalysisImpl::HandleRecv(Node* n,
                                        std::vector<bool>* should_revisit) {
  // In addition to being alive or dead based on the inputs, a _Recv can also
  // acquire a dead signal from a _Send.
  std::vector<Predicate*> input_preds;
  TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
  Predicate* signal_is_alive;
  TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
      n, /*output_idx=*/0, /*must_be_true=*/false, &signal_is_alive));
  input_preds.push_back(signal_is_alive);
  SetPredicate(n, {0, Graph::kControlSlot},
               predicate_factory_.MakeAndPredicate(input_preds),
               should_revisit);
  return Status::OK();
}

Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
                                           std::vector<bool>* should_revisit) {
  // Generally nodes are alive iff all their inputs are alive.
  std::vector<Predicate*> input_preds;
  TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
  Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds);
  for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
    SetPredicate(n, output_idx, pred, should_revisit);
  }
  SetPredicate(n, Graph::kControlSlot, pred, should_revisit);
  return Status::OK();
}

Status DeadnessAnalysisImpl::HandleNode(Node* n,
                                        std::vector<bool>* should_revisit,
                                        bool use_optimistic_mode) {
  if (n->IsSwitch()) {
    TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
  } else if (n->IsMerge()) {
    TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit, use_optimistic_mode));
  } else if (n->IsControlTrigger()) {
    SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
                 nullptr);
  } else if (n->IsRecv() || n->IsHostRecv()) {
    TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit));
  } else if (n->IsNextIteration()) {
    TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
  } else {
    TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
  }
  return Status::OK();
}

// Compute a special topological order for the Graph, where nodes having the
// same root frame are placed adjacent to each other.  The traversal uses a
// variant of Kahn's algorithm.  num_ready_inputs is used to keep track of how
// many inputs of each node are ready; a node is ready to be scheduled if all
// of its inputs are ready.
// Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details.
Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder(
    std::vector<Node*>* order) {
  absl::flat_hash_map<absl::string_view, size_t> num_enters_for_frame;
  absl::flat_hash_map<absl::string_view, size_t> num_exits_for_frame;
  std::vector<size_t> num_ready_inputs(graph_.num_node_ids(), 0);
  Node* src_node = graph_.source_node();
  for (const auto* node : graph_.op_nodes()) {
    const ControlFlowInfo& cf = control_flow_info_[node->id()];
    if (IsRootEnter(node)) {
      // Since we only care about the root-level frame, full frame names are
      // the same as frame names.
      ++num_enters_for_frame[cf.frame_name];
    } else if (IsRootExit(node)) {
      ++num_exits_for_frame[cf.frame_name];
    }
    // Edge NextIteration->Merge is counted before starting the traversal to
    // break the backedges.
    if (IsMerge(node)) {
      for (const Edge* e : node->in_edges()) {
        if (IsNextIteration(e->src())) {
          ++num_ready_inputs[node->id()];
        }
      }
    }
  }

  // A deque is used to ensure that the nodes are processed first-in-first-out.
  // This order guarantees that the exits in the ready queue are visited before
  // nodes that will become ready in the future.
  std::deque<Node*> ready;
  ready.push_back(src_node);
  // ready_enters_per_frame and ready_exits serve as a staging area to buffer
  // the ready enters/exits before they are moved to the `ready` queue for
  // controlling the start and end of a processing frame.
  absl::flat_hash_map<absl::string_view, std::vector<Node*>>
      ready_enters_per_frame;
  // Exit nodes are all from the same frame, as we process one frame at a
  // time, so one vector is enough.
1282   std::vector<Node*> ready_exits;
1283   while (!ready.empty()) {
1284     Node* curr_node = ready.front();
1285     ready.pop_front();
1286 
1287     VLOG(4) << "Visiting " << curr_node->name();
1288     order->push_back(curr_node);
1289 
1290     for (const Edge* out_edge : curr_node->out_edges()) {
1291       Node* out = out_edge->dst();
1292       int out_id = out->id();
1293       if (IsNextIteration(curr_node) && IsMerge(out)) {
1294         // Edge NextIteration->Merge has been counted.
1295         continue;
1296       }
1297       ++num_ready_inputs[out->id()];
1298       if (!out->IsOp()) continue;  // Skip Sink/Source nodes.
1299       if (num_ready_inputs[out->id()] != out->in_edges().size()) continue;
1300 
1301       absl::string_view frame_name = control_flow_info_[out_id].frame_name;
1302       if (IsRootEnter(out)) {
1303         ready_enters_per_frame[frame_name].push_back(out);
1304       } else if (IsRootExit(out)) {
1305         ready_exits.push_back(out);
1306       } else {
1307         ready.push_back(out);
1308       }
1309     }
1310 
1311     if (ready.empty()) {
1312       // Try moving nodes from ready_enters_per_frame and ready_exits to
1313       // `ready`.
1314       if (!ready_exits.empty()) {
1315         // If there are nodes in ready_exits we must process them before
1316         // processing ready_enters_per_frame to make sure all nodes in the
1317         // currently processing frame are visited before starting processing
1318         // other frames.
1319         absl::string_view frame_name =
1320             control_flow_info_[ready_exits.front()->id()].frame_name;
1321         CHECK_EQ(ready_exits.size(), num_exits_for_frame[frame_name]);
1322         ready.insert(ready.end(), ready_exits.begin(), ready_exits.end());
1323         ready_exits.clear();
1324       } else {
1325         // Otherwise, try moving nodes from ready_enters to `ready`.
1326         for (auto iter = ready_enters_per_frame.begin();
1327              iter != ready_enters_per_frame.end(); ++iter) {
1328           absl::string_view frame_name = iter->first;
1329           const std::vector<Node*>& ready_enters = iter->second;
1330           if (ready_enters.size() == num_enters_for_frame[frame_name]) {
1331             ready.insert(ready.end(), ready_enters.begin(), ready_enters.end());
1332             ready_enters_per_frame.erase(iter);
1333             break;
1334           }
1335         }
1336       }
1337     }
1338   }
1339 
1340   if (!ready_enters_per_frame.empty() || !ready_exits.empty()) {
1341     return errors::InvalidArgument(
1342         "Some enters/exits have never been visited in the traversal."
1343         " Most probably the input graph is malformed.");
1344   }
1345   return Status::OK();
1346 }
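
// Illustrative sketch (not used by this pass): plain Kahn's algorithm on a
// generic adjacency list, showing the basic ready-queue idea that the
// frame-based traversal above extends with per-frame staging of Enters and
// Exits.  `adj[u]` lists the successors of node u; assuming the input graph
// is acyclic, the returned vector is a topological order.
static std::vector<int> KahnTopologicalOrderSketch(
    const std::vector<std::vector<int>>& adj) {
  std::vector<int> indegree(adj.size(), 0);
  for (const std::vector<int>& successors : adj) {
    for (int v : successors) ++indegree[v];
  }
  std::deque<int> ready;
  for (int u = 0; u < static_cast<int>(adj.size()); ++u) {
    if (indegree[u] == 0) ready.push_back(u);
  }
  std::vector<int> order;
  while (!ready.empty()) {
    int u = ready.front();
    ready.pop_front();
    order.push_back(u);
    // A node becomes ready once all of its inputs have been emitted.
    for (int v : adj[u]) {
      if (--indegree[v] == 0) ready.push_back(v);
    }
  }
  return order;
}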

// We populate the nodes along a special topological order where nodes having
// the same root frame are placed adjacent to each other.  This grouping
// enables processing the graph one root frame at a time and guarantees that
// when a root frame is being processed, nodes in the downstream frames have
// not yet been processed.  This property is important because we need to
// process an entire frame to know whether the optimistic mode converges or
// not.  In other words, nodes in the downstream frames shall not be populated
// until all of their upstream frames are populated.  In effect, this order
// enables processing each (nested) tf.while one-by-one, as each (nested)
// tf.while creates a unique (root) frame.  Note that we don't separate while
// loops belonging to the same nested while, as there is no clean cut for
// separating them in the topological order.
Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) {
  std::vector<string> unreachable_nodes;
  // Compute the loop structure of the graph.
  TF_RETURN_IF_ERROR(
      BuildControlFlowInfo(&graph_, &control_flow_info_, &unreachable_nodes));

  // Do some opportunistic error checking:
  if (!unreachable_nodes.empty()) {
    if (unreachable_nodes.size() > 5) {
      unreachable_nodes.erase(unreachable_nodes.begin() + 5,
                              unreachable_nodes.end());
    }

    return errors::InvalidArgument(
        "Found unreachable nodes, most likely source and sink nodes not "
        "connected: ",
        absl::StrJoin(unreachable_nodes, ", "));
  }

  std::vector<Node*> topo;
  TF_RETURN_IF_ERROR(GetFrameBasedTopologicalOrder(&topo));

  size_t frame_start = 0;
  while (frame_start < topo.size()) {
    // Batch together nodes that have the same root frame.
    absl::string_view cur_frame_name;
    TF_RETURN_IF_ERROR(
        GetRootFrame(topo[frame_start], control_flow_info_, &cur_frame_name));
    size_t frame_end = frame_start;
    for (size_t i = frame_start + 1; i < topo.size(); ++i) {
      absl::string_view i_frame_name;
      TF_RETURN_IF_ERROR(
          GetRootFrame(topo[i], control_flow_info_, &i_frame_name));
      if (i_frame_name == cur_frame_name) {
        frame_end = i;
      } else {
        break;
      }
    }
    absl::Span<Node*> sub_topo(topo.data() + frame_start,
                               /*length=*/frame_end - frame_start + 1);
    frame_start = frame_end + 1;

    // First, try the optimistic mode.
    bool success = false;
    if (enable_optimistic && !cur_frame_name.empty()) {
      TF_RETURN_IF_ERROR(
          PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success));
    }
    if (!success) {
      // The optimistic mode did not converge; fall back to the pessimistic
      // mode.
      TF_RETURN_IF_ERROR(
          PopulateFrame(sub_topo, /*use_optimistic_mode=*/false, nullptr));
    }
    VLOG(2) << "Done populating frame " << cur_frame_name << " using the "
            << (success ? "optimistic" : "pessimistic") << " mode.";
  }

  return Status::OK();
}
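
// Illustrative sketch (not used by this pass): the frame-batching loop in
// Populate amounts to splitting a sequence into maximal runs of equal keys,
// here modeled with ints standing in for root frame names.  Returns the start
// index of each run, followed by a final sentinel equal to keys.size().
// For example, {1, 1, 2, 2, 2, 3} yields {0, 2, 5, 6}.
static std::vector<size_t> EqualKeyRunStartsSketch(
    const std::vector<int>& keys) {
  std::vector<size_t> run_starts;
  size_t start = 0;
  while (start < keys.size()) {
    run_starts.push_back(start);
    // Extend the current run while the key stays the same.
    size_t end = start + 1;
    while (end < keys.size() && keys[end] == keys[start]) ++end;
    start = end;
  }
  run_starts.push_back(keys.size());
  return run_starts;
}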

Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
                                           bool use_optimistic_mode,
                                           bool* success) {
  CHECK((use_optimistic_mode && success != nullptr) ||
        (!use_optimistic_mode && success == nullptr));

  // This is an abstract interpretation over the deadness propagation
  // semantics of the graph executor.
  //
  // We iterate over the graph twice, each time in a topological order.  On the
  // first iteration merge nodes with backedges are mapped to symbolic
  // predicates.  On the second iteration we use the predicates assigned to the
  // backedges in the previous iteration to infer a more precise predicate for
  // the backedge merge nodes and all the nodes that transitively use it.
  //
  // We don't track the output indices for should_revisit.  Instead, putting a
  // node in `should_revisit` denotes that the deadness flowing out from any
  // output of said node may have changed.  This is fine; only switches
  // propagate different deadness along different output edges, and since the
  // delta is solely due to the input *values* (and not input deadness), the
  // delta should not change in the second iteration.
  std::vector<bool> should_revisit;
  should_revisit.resize(graph_.num_node_ids());
  for (Node* n : topo) {
    VLOG(4) << "Visiting " << n->name();
    TF_RETURN_IF_ERROR(
        HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode));
    if (n->IsNextIteration()) {
      // If this is a backedge for a merge node then remember to reprocess the
      // merge the next time we run.
      for (const Edge* e : n->out_edges()) {
        if (e->dst()->IsMerge()) {
          should_revisit[e->dst()->id()] = true;
        }
      }
    }
  }

  for (Node* n : topo) {
    // The nodes added to should_revisit in the previous loop need to be
    // revisited now.  Reprocessing these initial nodes may add *their*
    // consumers to should_revisit, and these newly added nodes will also be
    // processed by this very same loop.  Since we're traversing the graph in
    // topological order (producers before consumers) and HandleNode(n) can
    // only ever add n's consumers to should_revisit, we won't "miss" an
    // addition to should_revisit.
    if (should_revisit[n->id()]) {
      VLOG(4) << "Revisiting " << n->name();
      TF_RETURN_IF_ERROR(HandleNode(n, &should_revisit));
    }
  }

  // Check if the optimistic analysis converges.  Specifically, check whether
  // all the predicates of the merge nodes in the same frame are the same.  If
  // yes, report success.  If not, report failure and clear the assigned
  // predicates.
  if (use_optimistic_mode) {
    bool is_converged = true;
    absl::flat_hash_map<absl::string_view, Predicate*> frame_to_pred;
    for (Node* n : topo) {
      if (!n->IsMerge()) {
        continue;
      }
      const Edge* e;
      TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &e));
      if (e == nullptr) {
        // Skip acyclic merge nodes.
        continue;
      }
      Node* merge = n;
      // Note that this uses frame names instead of root frame names.  In the
      // case of nested while loops, each level of while loops can have merges
      // with different predicate instances, while the merge nodes on the same
      // level must have the same predicate instances.
      absl::string_view frame_name = control_flow_info_[merge->id()].frame_name;
      auto it = predicate_map_.find(TensorId(merge->name(), 0));
      Predicate* merge_pred = it->second;
      if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) {
        is_converged = false;
        VLOG(2) << "Running the optimistic mode on frame " << frame_name
                << " does not converge because node " << merge->name()
                << " cannot be mapped into the AndRecurrence form.";
        break;
      }

      auto insert_result = frame_to_pred.insert({frame_name, merge_pred});
      if (!insert_result.second) {
        // If we have already seen this frame name, verify that the predicate
        // is the same as the previously seen one.
        Predicate* curr_andrec = merge_pred;
        Predicate* prev_andrec = insert_result.first->second;
        if (curr_andrec != prev_andrec) {
          is_converged = false;
          VLOG(2) << "Running the optimistic mode on frame " << frame_name
                  << " does not converge. Seeing different Merge predicates: \n"
                  << curr_andrec->ToString() << " and \n"
                  << prev_andrec->ToString();
          break;
        }
      }
    }

    // Clear the assigned predicates if the optimistic mode does not converge.
    if (!is_converged) {
      for (Node* n : topo) {
        for (int oid = 0; oid < n->num_outputs(); ++oid) {
          predicate_map_.erase(TensorId(n->name(), oid));
        }
        predicate_map_.erase(TensorId(n->name(), Graph::kControlSlot));
      }
    }

    if (success != nullptr) {
      *success = is_converged;
    }
  }

  return Status::OK();
}
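
// Illustrative sketch (not used by this pass): the revisit loop above, on the
// much simpler bool domain instead of symbolic predicates.  `preds[v]` and
// `succs[v]` list the predecessors and successors of node v, and node ids are
// assumed to already be in topological order, so a single forward sweep sees
// producers before consumers and therefore never misses a newly marked node.
static void RevisitSweepSketch(const std::vector<std::vector<int>>& preds,
                               const std::vector<std::vector<int>>& succs,
                               std::vector<bool>* live,
                               std::vector<bool>* should_revisit) {
  for (int v = 0; v < static_cast<int>(preds.size()); ++v) {
    if (!(*should_revisit)[v]) continue;
    // Recompute the node with the generic rule: live iff all inputs are live.
    bool new_live =
        absl::c_all_of(preds[v], [&](int p) { return (*live)[p]; });
    if (new_live != (*live)[v]) {
      (*live)[v] = new_live;
      // A change here may change the consumers, so mark them; they appear
      // later in this same sweep because ids are topologically ordered.
      for (int s : succs[v]) (*should_revisit)[s] = true;
    }
  }
}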

StatusOr<DeadnessAnalysis::DeadnessPredicate>
DeadnessAnalysisImpl::GetPredicateFor(Node* n, int oidx) const {
  auto it = predicate_map_.find(TensorId(n->name(), oidx));
  TF_RET_CHECK(it != predicate_map_.end())
      << "could not find " << TensorId(n->name(), oidx).ToString()
      << " in predicate map";
  return MakeDeadnessPredicate(it->second);
}

void DeadnessAnalysisImpl::Print() const {
  std::vector<TensorId> tensor_ids;
  for (const auto& kv_pair : predicate_map_) {
    tensor_ids.push_back(kv_pair.first);
  }

  std::sort(tensor_ids.begin(), tensor_ids.end());

  for (TensorId tensor_id : tensor_ids) {
    auto it = predicate_map_.find(tensor_id);
    CHECK(it != predicate_map_.end()) << tensor_id.ToString();
    VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString();
  }
}

}  // namespace

DeadnessAnalysis::~DeadnessAnalysis() {}

/*static*/ Status DeadnessAnalysis::Run(
    const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
  std::unique_ptr<DeadnessAnalysisImpl> analysis(
      new DeadnessAnalysisImpl(&graph));
  TF_RETURN_IF_ERROR(analysis->Populate(/*enable_optimistic=*/true));

  if (VLOG_IS_ON(2)) {
    analysis->Print();
  }

  *result = std::move(analysis);
  return Status::OK();
}
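
// Illustrative usage sketch (not called anywhere in this file): run the
// analysis on a graph and log the predicate computed for output 0 of a node.
// GetPredicateFor is assumed here to be part of the public DeadnessAnalysis
// interface declared in deadness_analysis.h.
static Status LogDeadnessPredicateSketch(const Graph& graph, Node* n) {
  std::unique_ptr<DeadnessAnalysis> deadness;
  TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));
  TF_ASSIGN_OR_RETURN(DeadnessAnalysis::DeadnessPredicate pred,
                      deadness->GetPredicateFor(n, /*oidx=*/0));
  VLOG(2) << n->name() << ":0 -> " << deadness->DebugString(pred);
  return Status::OK();
}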

absl::flat_hash_map<TensorId, string, TensorId::Hasher>
DeadnessAnalysisImpl::PredicateMapAsString() const {
  absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
  for (const auto& kv_pair : predicate_map_) {
    CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
  }
  return result;
}

namespace deadness_analysis_internal {
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
                         bool enable_optimistic) {
  DeadnessAnalysisImpl impl(&graph);
  TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic));
  *out_predicate_map = impl.PredicateMapAsString();
  return Status::OK();
}

}  // namespace deadness_analysis_internal
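
// Illustrative usage sketch (not called anywhere in this file): compute and
// log the string form of every predicate for a graph via the internal helper
// above.  PredicateMapTy is assumed to be the map type declared in
// deadness_analysis_internal.h.
static Status DumpPredicateMapSketch(const Graph& graph) {
  deadness_analysis_internal::PredicateMapTy predicate_map;
  TF_RETURN_IF_ERROR(deadness_analysis_internal::ComputePredicates(
      graph, &predicate_map, /*enable_optimistic=*/true));
  for (const auto& kv : predicate_map) {
    VLOG(2) << kv.first.ToString() << " -> " << kv.second;
  }
  return Status::OK();
}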

string DeadnessAnalysis::DebugString(DeadnessPredicate predicate) const {
  return static_cast<Predicate*>(predicate.pred_)->ToString();
}

}  // namespace tensorflow