1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
17 #define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
18 
19 #include "tensorflow/core/graph/graph.h"
20 #include "tensorflow/stream_executor/lib/statusor.h"
21 
22 namespace tensorflow {
23 
24 // This analyzes a TensorFlow graph to identify nodes which may have partially
25 // dead inputs (i.e. these nodes may have some dead inputs and some alive
26 // inputs).
27 //
28 // For example, the ADD node in the following graph
29 //
30 //      V0  PRED0    V1  PRED1
31 //       |    |       |    |
32 //       v    v       v    v
33 //       SWITCH       SWITCH
34 //          |            |
//          +---+   +----+
36 //              |   |
37 //              v   v
38 //               ADD
39 //
40 // can have its inputs independently dead or alive based on the runtime values
41 // of PRED0 and PRED1.
42 //
43 // It is tempting to call this a liveness analysis but I avoided that because
44 // "liveness" already has other connotations.
45 class DeadnessAnalysis {
46  public:
47   // An opaque representation of a predicate.  DeadnessPredicate
48   // instances that compare equal via operator== represent predicates
49   // that always evaluate to the same value.
50   struct DeadnessPredicate {
51    public:
52     DeadnessPredicate(const DeadnessPredicate&) = default;
53     DeadnessPredicate(DeadnessPredicate&&) = default;
54 
55     DeadnessPredicate& operator=(const DeadnessPredicate&) = default;
56     DeadnessPredicate& operator=(DeadnessPredicate&&) = default;
57 
58     bool operator==(const DeadnessPredicate& other) const {
59       return other.pred_ == pred_;
60     }
61 
62     bool operator!=(const DeadnessPredicate& other) const {
63       return other.pred_ != pred_;
64     }
65 
66    private:
DeadnessPredicateDeadnessPredicate67     explicit DeadnessPredicate(void* pred) : pred_(pred) {}
68 
69     // This is really a Predicate*, but we don't want to expose that
70     // implementation detail to our clients.  `pred_` has pointer equality so we
71     // can just compare the pointer in operator== and operator!=.
72     void* pred_;
73 
74     friend class DeadnessAnalysis;
75   };
76 
77   virtual se::port::StatusOr<DeadnessPredicate> GetPredicateFor(
78       Node* n, int oidx) const = 0;
79 
80   // Prints out the internal state of this instance.  For debugging purposes
81   // only.
82   virtual void Print() const = 0;
83   virtual ~DeadnessAnalysis();
84 
85   string DebugString(DeadnessPredicate predicate) const;
86 
87   // Run the deadness analysis over `graph` and returns an error or a populated
88   // instance of DeadnessAnalysis in `result`.
89   static Status Run(const Graph& graph,
90                     std::unique_ptr<DeadnessAnalysis>* result);
91 
92  protected:
MakeDeadnessPredicate(void * pred)93   static DeadnessPredicate MakeDeadnessPredicate(void* pred) {
94     return DeadnessPredicate(pred);
95   }
96 };
97 
98 }  // namespace tensorflow
99 
100 #endif  // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
101