1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/deadness_analysis.h"
17 #include "absl/algorithm/container.h"
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "absl/strings/str_join.h"
21 #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
22 #include "tensorflow/compiler/jit/xla_cluster_util.h"
23 #include "tensorflow/compiler/xla/status_macros.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/graph/algorithm.h"
26 #include "tensorflow/core/graph/control_flow.h"
27 #include "tensorflow/core/graph/tensor_id.h"
28 #include "tensorflow/core/lib/hash/hash.h"
29 
30 // ALGORITHM OVERVIEW
31 // ==================
32 //
33 // We map every output produced by each node in the TensorFlow graph (including
34 // control dependence) into an instance of the Predicate class.  Instances of
35 // Predicate denote logical formulas and mapping a node `n` to a predicate
36 // `pred` implies that `n` is live whenever `pred` is true.  Then we can deduce
// mismatching liveness in the inputs to a node by comparing the predicates
// those inputs are mapped to.  The core logic of this pass resides in creating
// the map from TensorFlow nodes to predicates.
40 //
41 //
42 // MAPPING NODES TO PREDICATES, MODULO CYCLES
43 // ------------------------------------------
44 //
45 // If we ignore cycles for a moment, computing predicates is fairly
46 // straightforward.  We traverse the graph in RPO, mapping each node to a
47 // predicate based on the predicates its inputs are mapped to.  For instance a
48 // Merge(X, Y) node will be mapped to OR(PredicateFor(X), PredicateFor(Y)).
// Roughly speaking, we abstract interpret each node on the "liveness" domain,
50 // where values in the domain represent if a tensor carries a dead signal or
51 // not.
52 //
53 //
54 // DEALING WITH CYCLES
55 // -------------------
56 //
57 // We map Merge nodes that are the target of a backedge to AndRecurrence
58 // instances.  An AndRecurrence with start() = S and step() = X, printed as
59 // {S,&,X}, *roughly* represents the infinite list of predicates
// [S,S&X,S&X&X,S&X&X&X, ...].  So {S,&,X} can be used to represent the
// predicate for Merge in a graph like:
62 //
63 //     Init
64 //       |
65 //       v
66 //     Merge <-----------+
67 //       |               |
68 //       v               |
69 //      Incr             |
70 //       |               |
71 //       v               |
72 //      Switch <- Cond   |
73 //       |               |
74 //       v (oidx: 1)     |
75 //       |               |
76 //       +---------------+
77 //
78 // Where S is the predicate for Init and X is the predicate that asserts that
79 // Cond is true.  {S,&,X} states that Merge is live on the first "iteration" iff
80 // S is true, live on the second iteration iff "S&X" is true, live on the third
81 // iteration iff "S&X&X" is true etc.  There is a subtlety here, S&X&X would
82 // normally be equivalent to S&X which isn't quite what we want to represent.
83 // Instead we want {S,&,X} to denote the infinite list [S, S&X,
84 // S&X&X',S&X&X'&X'', ...] where X, X', X'' are predicates that assert Cond is
85 // true on iteration 0, 1, 2 respectively.  This is made more precise in the
86 // comment on the AndRecurrence class.
87 //
88 // The general algorithm that deals with cycles does two RPO (reverse post
89 // order) passes over the graph.  On the first pass it assigns a symbolic
90 // predicate to merge nodes with backedges.  On the second pass it tries to
// pattern match the predicates for the backedges of these merges and infer an
92 // AndRecurrence for the merge.
93 //
94 // In other words, we do a pessimistic data flow analysis where the data-flow
95 // lattice has two elements, Symbolic and NonSymbolic with Symbolic >
96 // NonSymbolic. The lattice has height = 2 so two iterations are sufficient to
97 // converge.  We don't do an optimistic data flow analysis to make pattern
98 // matching easier: if we assigned the predicate of the initial value to the
99 // merge during the first pass, on the second pass the backedge may see a
100 // simplified value that would be difficult to pattern match.
101 //
102 // We still use symbolic predicates for merges for which we can't pattern match
103 // on the backedge predicate.  This is conservatively correct.
104 
105 namespace tensorflow {
106 
107 namespace {
108 
109 // Represents a logical predicate, used as described in the algorithm overview
110 // above.
111 class Predicate {
112  public:
113   enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol };
114 
115   virtual string ToString() const = 0;
116 
117   // An ID assigned to the Predicate at construction time.  Conceptually like a
118   // pointer, except that it is stable across runs.
id() const119   int64 id() const { return id_; }
120 
121   virtual absl::Span<Predicate* const> GetOperands() const = 0;
122 
123   virtual Kind kind() const = 0;
~Predicate()124   virtual ~Predicate() {}
125 
126   // Invokes func on p and on all of its operands recursively.  Does not invoke
127   // `func` on the same Predicate instance twice.  Aborts the search if `func`
128   // returns true.
129   template <typename FunctionTy>
130   static void Visit(Predicate* p, const FunctionTy& func);
131 
132  protected:
Predicate(int64 id)133   explicit Predicate(int64 id) : id_(id) {}
134 
135  private:
136   const int64 id_;
137 
138   TF_DISALLOW_COPY_AND_ASSIGN(Predicate);
139 };
140 
141 // Represents a logical conjunction of a set of predicates.
142 class AndPredicate : public Predicate {
143  public:
AndPredicate(int64 id,std::vector<Predicate * > operands)144   explicit AndPredicate(int64 id, std::vector<Predicate*> operands)
145       : Predicate(id), operands_(std::move(operands)) {}
146 
ToString() const147   string ToString() const override {
148     if (operands().empty()) {
149       return "#true";
150     }
151 
152     std::vector<string> operands_str;
153     std::transform(operands().begin(), operands().end(),
154                    std::back_inserter(operands_str),
155                    [](Predicate* pred) { return pred->ToString(); });
156 
157     return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
158   }
159 
kind() const160   Kind kind() const override { return Kind::kAnd; }
161 
GetOperands() const162   absl::Span<Predicate* const> GetOperands() const override {
163     return operands_;
164   }
operands() const165   absl::Span<Predicate* const> operands() const { return operands_; }
166 
167  private:
168   std::vector<Predicate*> operands_;
169 };
170 
171 // Represents a logical disjunction of a set of predicates.
172 class OrPredicate : public Predicate {
173  public:
OrPredicate(int64 id,std::vector<Predicate * > operands)174   explicit OrPredicate(int64 id, std::vector<Predicate*> operands)
175       : Predicate(id), operands_(std::move(operands)) {}
176 
ToString() const177   string ToString() const override {
178     if (operands().empty()) {
179       return "#false";
180     }
181 
182     std::vector<string> operands_str;
183     std::transform(operands().begin(), operands().end(),
184                    std::back_inserter(operands_str),
185                    [](Predicate* pred) { return pred->ToString(); });
186 
187     return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
188   }
189 
kind() const190   Kind kind() const override { return Kind::kOr; }
GetOperands() const191   absl::Span<Predicate* const> GetOperands() const override {
192     return operands_;
193   }
operands() const194   absl::Span<Predicate* const> operands() const { return operands_; }
195 
196  private:
197   std::vector<Predicate*> operands_;
198 };
199 
200 // Represents a logical negation of a set of predicates.
201 class NotPredicate : public Predicate {
202  public:
NotPredicate(int64 id,Predicate * operand)203   explicit NotPredicate(int64 id, Predicate* operand)
204       : Predicate(id), operands_({operand}) {}
205 
ToString() const206   string ToString() const override {
207     return absl::StrCat("~", operand()->ToString());
208   }
209 
kind() const210   Kind kind() const override { return Kind::kNot; }
operand() const211   Predicate* operand() const { return operands_[0]; }
GetOperands() const212   absl::Span<Predicate* const> GetOperands() const override {
213     return operands_;
214   }
215 
216  private:
217   std::array<Predicate*, 1> operands_;
218 };
219 
220 // Represents the liveness of an induction variable.  For users inside the loop
221 // this represents the "current" liveness of the induction variable.  For users
222 // outside the loop it represents the "last" liveness of the induction variable.
223 //
224 // More concretely, an and recurrence {S,&,X}<loop> represents the liveness of V
225 // in the following graph:
226 //
227 //   V = Merge(S', V_NextIt)
228 //   V = Op(V, X')
229 //   V_NextIt = NextIteration(V)
230 //
231 // where Predicate(S') = S and Predicate(X') = X.
232 //
233 // `X` may contain symbolic predicates and the operations corresponding to these
234 // symbolic predicates are either in frame `loop` or outside it.  The symbols
235 // that are inside frame `loop` are loop variant (i.e. can have different
236 // liveness in each loop iteration) and the symbols that are outside frame
237 // `loop` are loop invariant (i.e. have the same liveness across all
238 // iterations).
239 class AndRecurrencePredicate : public Predicate {
240  public:
AndRecurrencePredicate(int64 id,Predicate * start,Predicate * step,std::vector<string> frame)241   explicit AndRecurrencePredicate(int64 id, Predicate* start, Predicate* step,
242                                   std::vector<string> frame)
243       : Predicate(id), operands_({start, step}), frame_(std::move(frame)) {}
244 
start() const245   Predicate* start() const { return operands_[0]; }
step() const246   Predicate* step() const { return operands_[1]; }
frame() const247   absl::Span<const string> frame() const { return frame_; }
248 
ToString() const249   string ToString() const override {
250     return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
251                         "}<", absl::StrJoin(frame(), ";"), ">");
252   }
253 
kind() const254   Kind kind() const override { return Kind::kAndRecurrence; }
255 
GetOperands() const256   absl::Span<Predicate* const> GetOperands() const override {
257     return operands_;
258   }
259 
260  private:
261   std::array<Predicate*, 2> operands_;
262   std::vector<string> frame_;
263 };
264 
265 // Represents an uninterpreted symbol in a logical predicate.
266 //
267 // Two predicates are equivalent iff they are equivalent for all assignments to
268 // the symbols contained in them, i.e. predicates are forall qualified over
269 // symbols.
270 class SymbolPredicate : public Predicate {
271  public:
SymbolPredicate(int64 id,TensorId tensor_id,bool must_be_true)272   explicit SymbolPredicate(int64 id, TensorId tensor_id, bool must_be_true)
273       : Predicate(id),
274         tensor_id_(std::move(tensor_id)),
275         must_be_true_(must_be_true) {}
276 
ToString() const277   string ToString() const override {
278     return must_be_true() ? absl::StrCat("*", tensor_id_.ToString())
279                           : tensor_id_.ToString();
280   }
281 
kind() const282   Kind kind() const override { return Kind::kSymbol; }
GetOperands() const283   absl::Span<Predicate* const> GetOperands() const override { return {}; }
284 
285   // If `must_be_true()` is true this SymbolPredicate represents the proposition
286   // "tensor_id() is live and evaluates to true".
287   //
288   // If `must_be_true()` is false then this SymbolPredicate represents the
289   // proposition "tensor_id() is live (and may evaluate to any value)"
tensor_id() const290   TensorId tensor_id() const { return tensor_id_; }
must_be_true() const291   bool must_be_true() const { return must_be_true_; }
292 
293  private:
294   TensorId tensor_id_;
295   bool must_be_true_;
296 };
297 
298 template <typename FunctionTy>
Visit(Predicate * p,const FunctionTy & func)299 /*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) {
300   absl::flat_hash_set<Predicate*> visited;
301   std::vector<Predicate*> stack;
302 
303   stack.push_back(p);
304   visited.insert(p);
305 
306   while (!stack.empty()) {
307     Predicate* current = stack.back();
308     stack.pop_back();
309     bool done = func(current);
310     if (done) {
311       return;
312     }
313     for (Predicate* op : current->GetOperands()) {
314       if (visited.insert(op).second) {
315         stack.push_back(op);
316       }
317     }
318   }
319 }
320 
321 // Creates and owns Predicate instances.  Simplifies predicates as it creates
322 // them.
323 class PredicateFactory {
324  public:
MakeAndPredicate(absl::Span<Predicate * const> operands)325   Predicate* MakeAndPredicate(absl::Span<Predicate* const> operands) {
326     return MakeAndOrImpl(operands, /*is_and=*/true);
327   }
328 
MakeOrPredicate(absl::Span<Predicate * const> operands)329   Predicate* MakeOrPredicate(absl::Span<Predicate* const> operands) {
330     return MakeAndOrImpl(operands, /*is_and=*/false);
331   }
332 
MakeNotPredicate(Predicate * pred)333   Predicate* MakeNotPredicate(Predicate* pred) {
334     auto it = make_not_predicate_cache_.find(pred);
335     if (it != make_not_predicate_cache_.end()) {
336       return it->second;
337     }
338 
339     Predicate* result = MakeNotPredicateImpl(pred);
340 
341     bool insert_successful =
342         make_not_predicate_cache_.insert({pred, result}).second;
343     (void)insert_successful;
344     DCHECK(insert_successful);
345 
346     return result;
347   }
348 
MakeAndRecurrencePredicate(Predicate * start,Predicate * step,std::vector<string> frame)349   Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step,
350                                         std::vector<string> frame) {
351     SignatureForAndRec signature(start, step, std::move(frame));
352     auto it = interned_and_rec_instances_.find(signature);
353     if (it != interned_and_rec_instances_.end()) {
354       return it->second.get();
355     }
356 
357     std::unique_ptr<Predicate> new_pred = Make<AndRecurrencePredicate>(
358         std::get<0>(signature), std::get<1>(signature), std::get<2>(signature));
359     Predicate* new_pred_ptr = new_pred.get();
360     bool inserted =
361         interned_and_rec_instances_.emplace(signature, std::move(new_pred))
362             .second;
363     (void)inserted;
364     DCHECK(inserted);
365     return new_pred_ptr;
366   }
367 
MakeSymbolPredicate(Node * node,int output_idx,bool must_be_true,Predicate ** predicate)368   Status MakeSymbolPredicate(Node* node, int output_idx, bool must_be_true,
369                              Predicate** predicate) {
370     TensorId tensor_id(node->name(), output_idx);
371 
372     bool is_boolean_tensor = node->output_type(tensor_id.index()) == DT_BOOL;
373     TF_RET_CHECK(!must_be_true || is_boolean_tensor);
374 
375     if (node->type_string() == "Const" && must_be_true) {
376       const TensorProto* proto = nullptr;
377       TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "value", &proto));
378 
379       Tensor tensor(proto->dtype());
380       TF_RET_CHECK(tensor.FromProto(*proto));
381 
382       *predicate = tensor.scalar<bool>()() ? MakeTrue() : MakeFalse();
383       return Status::OK();
384     }
385 
386     SignatureForSymbol signature = {tensor_id, must_be_true};
387     auto it = interned_symbol_instances_.find(signature);
388     if (it == interned_symbol_instances_.end()) {
389       std::unique_ptr<Predicate> new_pred =
390           Make<SymbolPredicate>(tensor_id, must_be_true);
391       Predicate* new_pred_ptr = new_pred.get();
392       interned_symbol_instances_.emplace(std::move(signature),
393                                          std::move(new_pred));
394       *predicate = new_pred_ptr;
395     } else {
396       *predicate = it->second.get();
397     }
398 
399     return Status::OK();
400   }
401 
MakeTrue()402   Predicate* MakeTrue() { return MakeAndPredicate({}); }
MakeFalse()403   Predicate* MakeFalse() { return MakeOrPredicate({}); }
404 
~PredicateFactory()405   ~PredicateFactory() {
406     DCHECK_EQ(stack_depth_, 0) << "Unnested IncrementStackDepth?";
407   }
408 
409  private:
MakeNotPredicateImpl(Predicate * pred)410   Predicate* MakeNotPredicateImpl(Predicate* pred) {
411     IncrementStackDepth stack_frame(this);
412     if (!stack_frame.HasOverflowed()) {
413       if (Predicate* simplified = SimplifyUsingDeMorgan(pred)) {
414         return simplified;
415       }
416 
417       // ~~A => A
418       if (auto* not_pred = dynamic_cast<NotPredicate*>(pred)) {
419         return not_pred->operand();
420       }
421     }
422 
423     SignatureForNot signature = pred;
424     auto it = interned_not_instances_.find(signature);
425     if (it == interned_not_instances_.end()) {
426       std::unique_ptr<Predicate> new_pred = Make<NotPredicate>(pred);
427       Predicate* new_pred_ptr = new_pred.get();
428       interned_not_instances_.emplace(signature, std::move(new_pred));
429       return new_pred_ptr;
430     } else {
431       return it->second.get();
432     }
433   }
434 
SimplifyUsingDeMorgan(Predicate * pred)435   Predicate* SimplifyUsingDeMorgan(Predicate* pred) {
436     // ~(A & B & C & ...) => ~A | ~B | ~C | ~...
437     // ~(A | B | C | ...) -> ~A & ~B & ~C & ~...
438     Predicate::Kind kind = pred->kind();
439 
440     if (kind == Predicate::Kind::kAnd || kind == Predicate::Kind::kOr) {
441       std::vector<Predicate*> new_operands;
442       absl::c_transform(pred->GetOperands(), std::back_inserter(new_operands),
443                         [&](Predicate* p) { return MakeNotPredicate(p); });
444       return kind == Predicate::Kind::kOr ? MakeAndPredicate(new_operands)
445                                           : MakeOrPredicate(new_operands);
446     }
447 
448     return nullptr;
449   }
450 
451   template <typename PredicateT, typename... Args>
Make(Args &&...args)452   std::unique_ptr<Predicate> Make(Args&&... args) {
453     // If we ever expose the Predicate class outside this .cc file then we may
454     // want to make this hard to misuse (by accidentally passing in an arbitrary
455     // integer to the Predicate constructor for instance).
456     return std::unique_ptr<PredicateT>(
457         new PredicateT(id_counter_++, std::forward<Args>(args)...));
458   }
459 
460   Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
461   Predicate* MakeInternedAndOr(std::vector<Predicate*> simplified_ops,
462                                Predicate::Kind pred_kind);
463 
464   // Predicate instances are interned, meaning that there is only a single
465   // instance of a Predicate object with a given content.  This makes checking
466   // for structural equality super-cheap -- we can just compare pointers.
467   //
468   // We intern predicates by maintaining a map from the content of a Predicate
469   // to the only instance of said predicate we allow to exist in the
470   // interned_and_or_instances_, interned_not_instances_ and
471   // interned_symbol_instances_ fields.  These maps also double up as storage
472   // for the owning pointers to predicate instances.
473 
474   using SignatureForAndOr =
475       std::pair<Predicate::Kind, absl::Span<Predicate* const>>;
476   using SignatureForNot = Predicate*;
477   using SignatureForAndRec =
478       std::tuple<Predicate*, Predicate*, std::vector<string>>;
479   using SignatureForSymbol = std::pair<SafeTensorId, bool>;
480 
481   struct HashSignatureForAndOr {
operator ()tensorflow::__anonab73ef130111::PredicateFactory::HashSignatureForAndOr482     size_t operator()(const SignatureForAndOr& signature) const {
483       size_t hash = ::tensorflow::hash<Predicate::Kind>()(signature.first);
484       for (Predicate* p : signature.second) {
485         hash = Hash64Combine(hash, ::tensorflow::hash<Predicate*>()(p));
486       }
487       return hash;
488     }
489   };
490 
491   struct HashSignatureForSymbol {
operator ()tensorflow::__anonab73ef130111::PredicateFactory::HashSignatureForSymbol492     size_t operator()(const SignatureForSymbol& signature) const {
493       return Hash64Combine(SafeTensorId::Hasher()(signature.first),
494                            ::tensorflow::hash<bool>()(signature.second));
495     }
496   };
497 
498   // Used to limit recursion to avoid blowing up the stack and cap compile time.
499   class IncrementStackDepth {
500    public:
IncrementStackDepth(PredicateFactory * parent)501     explicit IncrementStackDepth(PredicateFactory* parent) : parent_(parent) {
502       parent_->stack_depth_++;
503     }
504 
HasOverflowed() const505     bool HasOverflowed() const {
506       const int kMaxStackDepth = 8;
507       return parent_->stack_depth_ >= kMaxStackDepth;
508     }
509 
~IncrementStackDepth()510     ~IncrementStackDepth() { parent_->stack_depth_--; }
511 
512    private:
513     PredicateFactory* parent_;
514   };
515 
516   // A cache for the MakeNotPredicate function.
517   //
518   // NB! This is *not* the same as `interned_not_instances_`.
519   // `interned_not_instances_` maps ensures pointer identity for `NotPredicate`
520   // instances, i.e., it ensures there at most one instance of Not(predicate)
521   // for any given predicate whereas `make_not_predicate_cache_` simply caches
522   // the result of the `MakeNotPredicate` function.  The values in
523   // `interned_not_instances_` are always instance of `NotPredicate` whereas the
524   // values in `make_not_predicate_cache_` may not be (for instance it will map
525   // Not(Not(A)) to A).
526   absl::flat_hash_map<Predicate*, Predicate*> make_not_predicate_cache_;
527 
528   absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>,
529                       HashSignatureForAndOr>
530       interned_and_or_instances_;
531   absl::flat_hash_map<SignatureForNot, std::unique_ptr<Predicate>>
532       interned_not_instances_;
533   absl::flat_hash_map<SignatureForAndRec, std::unique_ptr<Predicate>>
534       interned_and_rec_instances_;
535   absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>,
536                       HashSignatureForSymbol>
537       interned_symbol_instances_;
538   int64 id_counter_ = 0;
539   int stack_depth_ = 0;
540 };
541 
MakeInternedAndOr(std::vector<Predicate * > simplified_ops,Predicate::Kind pred_kind)542 Predicate* PredicateFactory::MakeInternedAndOr(
543     std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) {
544   std::stable_sort(
545       simplified_ops.begin(), simplified_ops.end(),
546       [](Predicate* a, Predicate* b) { return a->id() < b->id(); });
547 
548   auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
549   if (it != interned_and_or_instances_.end()) {
550     return it->second.get();
551   }
552 
553   simplified_ops.shrink_to_fit();
554   // NB!  Because we'll use a non-owning reference to simplified_ops in the
555   // key for interned_and_or_instances_ we need to be careful to std::move()
556   // it all the way through.
557   absl::Span<Predicate* const> operands_slice = simplified_ops;
558   std::unique_ptr<Predicate> new_pred =
559       pred_kind == Predicate::Kind::kAnd
560           ? Make<AndPredicate>(std::move(simplified_ops))
561           : Make<OrPredicate>(std::move(simplified_ops));
562 
563   Predicate* new_pred_ptr = new_pred.get();
564   interned_and_or_instances_.emplace(
565       SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred));
566   return new_pred_ptr;
567 }
568 
// Common code to create AndPredicate or OrPredicate instances.
//
// Builds the conjunction (`is_and` == true) or disjunction (`is_and` == false)
// of `operands`, applying the simplifications annotated below.  The returned
// predicate is not necessarily an And/Or of the requested kind: it can be an
// existing operand, the constant true/false predicate, or a predicate of the
// other kind produced by factoring.
Predicate* PredicateFactory::MakeAndOrImpl(
    absl::Span<Predicate* const> operands, bool is_and) {
  Predicate::Kind pred_kind =
      is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;

  // Once the recursion limit has been reached, skip every simplification and
  // intern the operands as given.
  IncrementStackDepth stack_frame(this);
  if (stack_frame.HasOverflowed()) {
    return MakeInternedAndOr(
        std::vector<Predicate*>(operands.begin(), operands.end()), pred_kind);
  }

  Predicate::Kind other_pred_kind =
      is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
  // `simplified_ops` keeps the deduplicated, flattened operands in first-seen
  // order; `simplified_ops_set` tracks everything seen so far (including the
  // nested And/Or nodes that were inlined).
  absl::flat_hash_set<Predicate*> simplified_ops_set;
  std::vector<Predicate*> simplified_ops;
  for (Predicate* op : operands) {
    // Simplify A&A => A and  A|A => A.
    if (!simplified_ops_set.insert(op).second) {
      continue;
    }

    if (op->kind() == pred_kind) {
      // "Inline" the operands of an inner And/Or into the parent And/Or.
      for (Predicate* subop : op->GetOperands()) {
        if (simplified_ops_set.insert(subop).second) {
          simplified_ops.push_back(subop);
        }
      }
    } else {
      simplified_ops.push_back(op);
    }
  }

  // A single remaining operand is its own conjunction/disjunction.
  if (simplified_ops.size() == 1) {
    return simplified_ops[0];
  }

  // Simplify "A&~A=>False" and "A|~A=>True".
  //
  // `negated_ops` accumulates the negation of every operand visited so far, so
  // an operand found in it is the negation of an earlier operand.
  absl::flat_hash_set<Predicate*> negated_ops;
  for (Predicate* op : simplified_ops) {
    if (negated_ops.count(op)) {
      // Simple case:
      //
      //   A & ~A & ... == False
      //   A | ~A | ... == True
      return is_and ? MakeFalse() : MakeTrue();
    }

    Predicate* negated_op = MakeNotPredicate(op);
    if (negated_op->kind() == pred_kind) {
      // Slightly more complicated case:
      //
      //   (~A | ~B | ~C) & A & B & C & ... ==
      //   ~(A & B & C) & (A & B & C) & ... == False
      //
      //   (~A & ~B & ~C) | A | B | C | ... ==
      //   ~(A | B | C) | (A | B | C) | ... == True
      if (absl::c_all_of(negated_op->GetOperands(), [&](Predicate* p) {
            return simplified_ops_set.contains(p);
          })) {
        return is_and ? MakeFalse() : MakeTrue();
      }
    }
    negated_ops.insert(negated_op);
  }

  // If all ops contain the same subop, then factor it out thanks to the
  // distributive property. Such as:
  // - (A & B) | (A & C) | (A & D) => A & (B | C | D)
  // - (A | B) & (A | C) & (A | D) => A | (B & C & D)
  //
  // First find any predicates contained in all subops.
  std::vector<Predicate*> common_inner_operands;
  absl::flat_hash_set<Predicate*> common_inner_operands_set;
  for (Predicate* op : simplified_ops) {
    // Factoring only applies when every operand is of the other kind.
    if (op->kind() != other_pred_kind) {
      common_inner_operands.clear();
      break;
    }

    // `common_inner_operands` can only be empty here on the first iteration
    // (later iterations break out below as soon as it becomes empty), so this
    // branch seeds the candidates from the first op and the else branch
    // intersects them with each subsequent op.
    if (common_inner_operands.empty()) {
      common_inner_operands.insert(common_inner_operands.end(),
                                   op->GetOperands().begin(),
                                   op->GetOperands().end());
    } else {
      common_inner_operands.clear();
      absl::c_copy_if(op->GetOperands(),
                      std::back_inserter(common_inner_operands),
                      [&](Predicate* sub_op) {
                        return common_inner_operands_set.count(sub_op) == 1;
                      });
    }
    if (common_inner_operands.empty()) break;
    // Keep the set in sync with the vector for the next intersection.
    common_inner_operands_set.clear();
    common_inner_operands_set.insert(common_inner_operands.begin(),
                                     common_inner_operands.end());
  }

  // Nothing to factor out: intern the flattened operands as they are.
  if (common_inner_operands.empty()) {
    return MakeInternedAndOr(std::move(simplified_ops), pred_kind);
  }

  // For all predicates that can be factored out, remove them and recreate the
  // subops.
  std::vector<Predicate*> factored_ops;
  for (Predicate* op : simplified_ops) {
    std::vector<Predicate*> new_sub_op_ops;
    absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops),
                    [&](Predicate* sub_op) {
                      return std::find(common_inner_operands.begin(),
                                       common_inner_operands.end(),
                                       sub_op) == common_inner_operands.end();
                    });
    factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and));
  }

  // Combine the stripped subops with the requested kind, then join that result
  // with the common operands using the other kind, e.g. A & (B | C | D).
  Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and);
  std::vector<Predicate*> outer_ops;
  outer_ops.push_back(new_inner_op);
  outer_ops.insert(outer_ops.end(), common_inner_operands.begin(),
                   common_inner_operands.end());
  return MakeAndOrImpl(outer_ops, !is_and);
}
693 
694 class DeadnessAnalysisImpl : public DeadnessAnalysis {
695  public:
DeadnessAnalysisImpl(const Graph * graph)696   explicit DeadnessAnalysisImpl(const Graph* graph)
697       : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
698 
699   Status Populate();
700   Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
701   bool HasInputsWithMismatchingDeadness(const Node& node) override;
702   void Print() const override;
703   absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
704       const;
705 
706  private:
707   enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
708 
709   Status GetInputPreds(Node* n, EdgeKind edge_kind,
710                        std::vector<Predicate*>* result);
711 
712   // Sets the predicate for output `output_idx` of `n` to `pred`.  Sets the i'th
713   // bit of `should_revisit` if `pred` is different from the current predicate
714   // for the `output_idx` output of `n`.
SetPredicate(Node * n,int output_idx,Predicate * pred,std::vector<bool> * should_revisit)715   void SetPredicate(Node* n, int output_idx, Predicate* pred,
716                     std::vector<bool>* should_revisit) {
717     auto insert_result =
718         predicate_map_.insert({TensorId(n->name(), output_idx), pred});
719     if (!insert_result.second && insert_result.first->second != pred) {
720       VLOG(4) << "For " << n->name() << ":" << output_idx << " from "
721               << insert_result.first->second->ToString() << " "
722               << insert_result.first->second << " to " << pred->ToString()
723               << " " << pred;
724       insert_result.first->second = pred;
725       if (should_revisit != nullptr) {
726         for (const Edge* e : n->out_edges()) {
727           (*should_revisit)[e->dst()->id()] = true;
728         }
729       }
730     }
731   }
732 
SetPredicate(Node * n,absl::Span<const int> output_idxs,Predicate * pred,std::vector<bool> * should_revisit)733   void SetPredicate(Node* n, absl::Span<const int> output_idxs, Predicate* pred,
734                     std::vector<bool>* should_revisit) {
735     for (int output_idx : output_idxs) {
736       SetPredicate(n, output_idx, pred, should_revisit);
737     }
738   }
739 
  // Per-node-kind handlers.  Each one derives the predicates for the outputs
  // of `n` and records them through SetPredicate, passing `should_revisit`
  // along.
  Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
  Status HandleMerge(Node* n, std::vector<bool>* should_revisit);
  Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
  Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
  // Dispatches `n` to the handler that matches its node kind.
  Status HandleNode(Node* n, std::vector<bool>* should_revisit);

  // The graph being analyzed.
  const Graph& graph_;
  // Maps (node name, output index) to the predicate recorded for that output.
  absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
  // Creates the Predicate instances stored in `predicate_map_`.
  PredicateFactory predicate_factory_;
  // Filled in by BuildControlFlowInfo in PopulateWithReversePostOrder and
  // indexed by node id.
  std::vector<ControlFlowInfo> control_flow_info_;
  // Gates the VLOG(2) tracing in HasInputsWithMismatchingDeadness.
  bool vlog_;
751 };
752 
InputEdgeToTensorId(const Edge * e)753 TensorId InputEdgeToTensorId(const Edge* e) {
754   return TensorId(e->src()->name(), e->src_output());
755 }
756 
GetInputPreds(Node * n,DeadnessAnalysisImpl::EdgeKind edge_kind,std::vector<Predicate * > * result)757 Status DeadnessAnalysisImpl::GetInputPreds(
758     Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind,
759     std::vector<Predicate*>* result) {
760   result->clear();
761   for (const Edge* in_edge : n->in_edges()) {
762     bool should_process =
763         edge_kind == EdgeKind::kDataAndControl ||
764         (in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) ||
765         (!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly);
766 
767     if (should_process) {
768       auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
769       if (it == predicate_map_.end()) {
770         GraphCycles graph_cycles;
771         TF_RETURN_IF_ERROR(
772             CreateCycleDetectionGraph(&graph_, &graph_cycles).status());
773 
774         // If we didn't return with an error above then the graph is probably
775         // fine and we have a bug in deadness analysis.
776         return errors::Internal("Could not find input ", in_edge->DebugString(),
777                                 " to ", n->name(),
778                                 " when visiting the graph in post-order.  Most "
779                                 "likely indicates a bug in deadness analysis.");
780       }
781       result->push_back(it->second);
782     }
783   }
784   return Status::OK();
785 }
786 
HandleSwitch(Node * n,std::vector<bool> * should_revisit)787 Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
788                                           std::vector<bool>* should_revisit) {
789   std::vector<Predicate*> input_preds;
790   TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
791   const Edge* pred_edge;
792   TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
793 
794   Predicate* true_switch;
795   TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
796       pred_edge->src(), pred_edge->src_output(),
797       /*must_be_true=*/true, &true_switch));
798 
799   Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);
800 
801   // Output 0 is alive iff all inputs are alive and the condition is false.
802   input_preds.push_back(false_switch);
803   SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds),
804                should_revisit);
805   input_preds.pop_back();
806 
807   // Output 1 is alive iff all inputs are alive and the condition is true.
808   input_preds.push_back(true_switch);
809   SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds),
810                should_revisit);
811   input_preds.pop_back();
812 
813   // Control is alive iff all inputs are alive.
814   SetPredicate(n, Graph::kControlSlot,
815                predicate_factory_.MakeAndPredicate(input_preds),
816                should_revisit);
817 
818   return Status::OK();
819 }
820 
821 namespace {
CreateMultipleNextIterationInputsError(Node * merge)822 Status CreateMultipleNextIterationInputsError(Node* merge) {
823   std::vector<string> backedges;
824   for (const Edge* backedge : merge->in_edges()) {
825     if (backedge->src()->IsNextIteration()) {
826       backedges.push_back(absl::StrCat("  ", SummarizeNode(*backedge->src())));
827     }
828   }
829   return errors::InvalidArgument(
830       "Multiple NextIteration inputs to merge node ",
831       FormatNodeForError(*merge), ": \n", absl::StrJoin(backedges, "\n"),
832       "\nMerge nodes can have at most one incoming NextIteration edge.");
833 }
834 
FindUniqueBackedge(Node * merge,const Edge ** result)835 Status FindUniqueBackedge(Node* merge, const Edge** result) {
836   *result = nullptr;
837   CHECK(merge->IsMerge());
838   for (const Edge* e : merge->in_edges()) {
839     if (e->src()->IsNextIteration()) {
840       if (*result != nullptr) {
841         return CreateMultipleNextIterationInputsError(merge);
842       }
843       *result = e;
844     }
845   }
846   return Status::OK();
847 }
848 
849 // If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step
850 // does not contain `symbolic_predicate` as an inner (not top-level) operand
851 // then returns `Step`.  Otherwise returns nullptr.
DeduceStepPredicate(PredicateFactory * predicate_factory,Predicate * symbolic_predicate,Predicate * backedge_predicate)852 Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory,
853                                Predicate* symbolic_predicate,
854                                Predicate* backedge_predicate) {
855   CHECK(dynamic_cast<SymbolPredicate*>(symbolic_predicate));
856   if (backedge_predicate->kind() != Predicate::Kind::kAnd) {
857     return nullptr;
858   }
859 
860   std::vector<Predicate*> and_ops;
861   absl::Span<Predicate* const> recurrent_pred_ops =
862       backedge_predicate->GetOperands();
863 
864   bool found_sym = false;
865   for (Predicate* and_op : recurrent_pred_ops) {
866     // We want the `symbol_predicate` to be the one of the operands of
867     // `backedge_predicate`,
868     if (and_op == symbolic_predicate) {
869       found_sym = true;
870       continue;
871     }
872 
873     // but we don't want it to be present anywhere else in the formula.  E.g. we
874     // don't want the recurrent predicate to be
875     // symbol_predicate&(X|symbol_predicate).
876     bool found_sym_as_inner_operand = false;
877     auto has_self_as_inner_operand = [&](Predicate* p) {
878       if (p == symbolic_predicate) {
879         found_sym_as_inner_operand = true;
880         return true;  // Stop searching, we're done.
881       }
882 
883       // Continue searching.
884       return false;
885     };
886 
887     Predicate::Visit(and_op, has_self_as_inner_operand);
888     if (found_sym_as_inner_operand) {
889       return nullptr;
890     }
891     and_ops.push_back(and_op);
892   }
893 
894   return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr;
895 }
896 
GetFullFrame(const Node * n,absl::Span<const ControlFlowInfo> cfi_infos,std::vector<string> * frame)897 Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
898                     std::vector<string>* frame) {
899   int depth = 0;
900   for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource();
901        n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) {
902     frame->push_back(cfi_iter->frame_name);
903 
904     if (depth++ > 5000) {
905       return errors::Internal(
906           "Frame of depth > 5000:  Probably malformed graph or a bug in "
907           "BuildControlFlowInfo");
908     }
909   }
910 
911   return Status::OK();
912 }
913 }  // namespace
914 
HandleMerge(Node * n,std::vector<bool> * should_revisit)915 Status DeadnessAnalysisImpl::HandleMerge(Node* n,
916                                          std::vector<bool>* should_revisit) {
917   // Merge ignores deadness of its control inputs.  A merge that isn't the
918   // target of a backedge has is alive iff any of its data inputs are.  The
919   // liveness of a merge that is the target of a backedge can sometimes be
920   // represented using a AndRecurrencePredicate.  If neither apply, we represent
921   // the liveness of the merge symbolically.
922 
923   bool has_unvisited_backedge = false;
924   for (const Edge* e : n->in_edges()) {
925     if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
926       has_unvisited_backedge |= !predicate_map_.count(InputEdgeToTensorId(e));
927     }
928   }
929 
930   auto it = predicate_map_.find(TensorId(n->name(), 0));
931   if (it == predicate_map_.end()) {
932     if (has_unvisited_backedge) {
933       // We're visiting this merge for the first time and it has an unvisited
934       // backedge.
935       Predicate* input_data_pred;
936       TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
937           n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
938 
939       SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
940                    should_revisit);
941       return Status::OK();
942     }
943 
944     std::vector<Predicate*> input_preds;
945     TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
946 
947     // We're visiting this merge for the first time and it is a acyclic merge.
948     Predicate* input_data_pred =
949         predicate_factory_.MakeOrPredicate(input_preds);
950     SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
951                  should_revisit);
952     return Status::OK();
953   }
954 
955   if (it->second->kind() == Predicate::Kind::kSymbol) {
956     // Last time we visited this merge we only got a symbolic predicate because
957     // of an unvisited backedge.  Try to pattern match the predicate expression
958     // for that backedge (which should be visited now) into an and recurrence
959     // for the merge node.
960     const Edge* unique_backedge;
961     TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge));
962     if (unique_backedge) {
963       if (Predicate* step = DeduceStepPredicate(
964               &predicate_factory_, it->second,
965               predicate_map_[InputEdgeToTensorId(unique_backedge)])) {
966         // If the predicate for the backedge is "Sym&X" where "Sym" is the
967         // predicate for the merge then the merge has predicate {S,&,X} where S
968         // is the predicate for the merge ignoring the backedge.
969         std::vector<Predicate*> non_recurrent_inputs;
970         for (const Edge* e : n->in_edges()) {
971           if (e != unique_backedge) {
972             non_recurrent_inputs.push_back(
973                 predicate_map_[InputEdgeToTensorId(e)]);
974           }
975         }
976 
977         Predicate* start =
978             predicate_factory_.MakeOrPredicate(non_recurrent_inputs);
979         std::vector<string> frame;
980         TF_RETURN_IF_ERROR(GetFullFrame(n, control_flow_info_, &frame));
981         Predicate* and_rec = predicate_factory_.MakeAndRecurrencePredicate(
982             start, step, std::move(frame));
983         SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit);
984         return Status::OK();
985       }
986     }
987   }
988   return Status::OK();
989 }
990 
HandleRecv(Node * n,std::vector<bool> * should_revisit)991 Status DeadnessAnalysisImpl::HandleRecv(Node* n,
992                                         std::vector<bool>* should_revisit) {
993   // In addition to being alive or dead based on the inputs, a _Recv can also
994   // acquire a dead signal from a _Send.
995   std::vector<Predicate*> input_preds;
996   TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
997   Predicate* signal_is_alive;
998   TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
999       n, /*output_idx=*/0, /*must_be_true=*/false, &signal_is_alive));
1000   input_preds.push_back(signal_is_alive);
1001   SetPredicate(n, {0, Graph::kControlSlot},
1002                predicate_factory_.MakeAndPredicate(input_preds),
1003                should_revisit);
1004   return Status::OK();
1005 }
1006 
HandleGeneric(Node * n,std::vector<bool> * should_revisit)1007 Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
1008                                            std::vector<bool>* should_revisit) {
1009   // Generally nodes are alive iff all their inputs are alive.
1010   std::vector<Predicate*> input_preds;
1011   TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
1012   Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds);
1013   for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
1014     SetPredicate(n, output_idx, pred, should_revisit);
1015   }
1016   SetPredicate(n, Graph::kControlSlot, pred, should_revisit);
1017   return Status::OK();
1018 }
1019 
HandleNode(Node * n,std::vector<bool> * should_revisit)1020 Status DeadnessAnalysisImpl::HandleNode(Node* n,
1021                                         std::vector<bool>* should_revisit) {
1022   if (n->IsSwitch()) {
1023     TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
1024   } else if (n->IsMerge()) {
1025     TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit));
1026   } else if (n->IsControlTrigger()) {
1027     SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
1028                  nullptr);
1029   } else if (n->IsRecv() || n->IsHostRecv()) {
1030     TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit));
1031   } else if (n->IsNextIteration()) {
1032     TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
1033   } else {
1034     TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit));
1035   }
1036   return Status::OK();
1037 }
1038 
Populate()1039 Status DeadnessAnalysisImpl::Populate() {
1040   std::vector<Node*> rpo;
1041   GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(),
1042                       /*edge_filter=*/[](const Edge& edge) {
1043                         return !edge.src()->IsNextIteration();
1044                       });
1045   return PopulateWithReversePostOrder(rpo);
1046 }
1047 
PopulateWithReversePostOrder(absl::Span<Node * const> rpo)1048 Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
1049     absl::Span<Node* const> rpo) {
1050   std::vector<string> unreachable_nodes;
1051   // Compute the loop structure of the graph.
1052   TF_RETURN_IF_ERROR(
1053       BuildControlFlowInfo(&graph_, &control_flow_info_, &unreachable_nodes));
1054 
1055   // Do some opportunistic error checking:
1056   if (!unreachable_nodes.empty()) {
1057     if (unreachable_nodes.size() > 5) {
1058       unreachable_nodes.erase(unreachable_nodes.begin() + 5,
1059                               unreachable_nodes.end());
1060     }
1061 
1062     return errors::InvalidArgument(
1063         "Found unreachable nodes, most likely source and sink nodes not "
1064         "connected: ",
1065         absl::StrJoin(unreachable_nodes, ", "));
1066   }
1067 
1068   // This an abstract interpretation over the deadness propagation semantics of
1069   // the graph executor.
1070   //
1071   // We iterate over the graph twice, each time in RPO.  On the first iteration
1072   // merge nodes with backedges are mapped to symbolic predicates.  On the
1073   // second iteration we use the predicates assigned to the backedges in the
1074   // previous iteration to infer a more precise predicate for the backedge merge
1075   // nodes and all the nodes that transitively use it.
1076   //
1077   // We don't track the output indices for should_revisit.  Instead, putting a
1078   // node in `should_revisit` denotes that the deadness flowing out from any
1079   // output from said node may have changed.  This is fine; only switches
1080   // propagate different deadness along different output edges, and since the
1081   // delta is solely due to the input *values* (and not input deadness), the
1082   // delta should not change in the second iteration.
1083   std::vector<bool> should_revisit;
1084   should_revisit.resize(graph_.num_node_ids());
1085   for (Node* n : rpo) {
1086     VLOG(4) << "Visiting " << n->name();
1087     TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr));
1088     if (n->IsNextIteration()) {
1089       // If this is a backedge for a merge node then remember to reprocess the
1090       // merge the next time we run.
1091       for (const Edge* e : n->out_edges()) {
1092         if (e->dst()->IsMerge()) {
1093           should_revisit[e->dst()->id()] = true;
1094         }
1095       }
1096     }
1097   }
1098 
1099   for (Node* n : rpo) {
1100     // The nodes added to should_revisit in the previous loop need to be
1101     // revisited now.  Reprocesing these initial nodes may add *their* consumers
1102     // to should_revisit, and these newly added nodes will also be processed by
1103     // this very same loop.  Since we're traversing the graph in reverse post
1104     // order (producers before consumers) and HandleNode(n) can only ever add
1105     // n's consumers to should_revisit, we won't "miss" an addition to
1106     // should_revisit.
1107     if (should_revisit[n->id()]) {
1108       VLOG(4) << "Revisiting " << n->name();
1109       TF_RETURN_IF_ERROR(HandleNode(n, &should_revisit));
1110     }
1111   }
1112 
1113   return Status::OK();
1114 }
1115 
HasInputsWithMismatchingDeadness(const Node & node)1116 bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) {
1117   CHECK(!node.IsMerge());
1118 
1119   if (vlog_) {
1120     VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")";
1121   }
1122 
1123   Predicate* pred = nullptr;
1124   for (const Edge* edge : node.in_edges()) {
1125     auto it = predicate_map_.find(InputEdgeToTensorId(edge));
1126     CHECK(it != predicate_map_.end());
1127     if (vlog_) {
1128       VLOG(2) << "  " << InputEdgeToTensorId(edge).ToString() << ": "
1129               << it->second->ToString();
1130     }
1131 
1132     // Today we just compare the predicates for equality (with some
1133     // canonicalization/simplification happening before) but we could be more
1134     // sophisticated here if need be.  Comparing pointers is sufficient because
1135     // we intern Predicate instances by their content.
1136     if (pred != nullptr && pred != it->second) {
1137       if (vlog_) {
1138         VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
1139                 << ") -> true";
1140       }
1141       return true;
1142     }
1143     pred = it->second;
1144   }
1145 
1146   if (vlog_) {
1147     VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
1148             << ") -> false";
1149   }
1150 
1151   return false;
1152 }
1153 
Print() const1154 void DeadnessAnalysisImpl::Print() const {
1155   std::vector<TensorId> tensor_ids;
1156   for (const auto& kv_pair : predicate_map_) {
1157     tensor_ids.push_back(kv_pair.first);
1158   }
1159 
1160   std::sort(tensor_ids.begin(), tensor_ids.end());
1161 
1162   for (TensorId tensor_id : tensor_ids) {
1163     auto it = predicate_map_.find(tensor_id);
1164     CHECK(it != predicate_map_.end()) << tensor_id.ToString();
1165     VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString();
1166   }
1167 }
1168 
1169 }  // namespace
1170 
~DeadnessAnalysis()1171 DeadnessAnalysis::~DeadnessAnalysis() {}
1172 
Run(const Graph & graph,std::unique_ptr<DeadnessAnalysis> * result)1173 /*static*/ Status DeadnessAnalysis::Run(
1174     const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
1175   std::unique_ptr<DeadnessAnalysisImpl> analysis(
1176       new DeadnessAnalysisImpl(&graph));
1177   TF_RETURN_IF_ERROR(analysis->Populate());
1178 
1179   if (VLOG_IS_ON(2)) {
1180     analysis->Print();
1181   }
1182 
1183   *result = std::move(analysis);
1184   return Status::OK();
1185 }
1186 
1187 absl::flat_hash_map<TensorId, string, TensorId::Hasher>
PredicateMapAsString() const1188 DeadnessAnalysisImpl::PredicateMapAsString() const {
1189   absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
1190   std::vector<TensorId> tensor_ids;
1191   for (const auto& kv_pair : predicate_map_) {
1192     CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
1193   }
1194   return result;
1195 }
1196 
1197 namespace deadness_analysis_internal {
ComputePredicates(const Graph & graph,PredicateMapTy * out_predicate_map)1198 Status ComputePredicates(const Graph& graph,
1199                          PredicateMapTy* out_predicate_map) {
1200   DeadnessAnalysisImpl impl(&graph);
1201   TF_RETURN_IF_ERROR(impl.Populate());
1202   *out_predicate_map = impl.PredicateMapAsString();
1203   return Status::OK();
1204 }
1205 
ComputePredicates(const Graph & graph,absl::Span<Node * const> reverse_post_order,PredicateMapTy * out_predicate_map)1206 Status ComputePredicates(const Graph& graph,
1207                          absl::Span<Node* const> reverse_post_order,
1208                          PredicateMapTy* out_predicate_map) {
1209   DeadnessAnalysisImpl impl(&graph);
1210   TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order));
1211   *out_predicate_map = impl.PredicateMapAsString();
1212   return Status::OK();
1213 }
1214 }  // namespace deadness_analysis_internal
1215 
1216 }  // namespace tensorflow
1217