1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
18 
19 #include <cstdio>
20 #include <list>
21 #include <vector>
22 
23 #include "absl/base/casts.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/map_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 
36 // A class for representing reachability between HloInstructions.
37 //
38 // It has an adjacency matrix and it is up to the user of the class to set the
39 // adjacency matrix such that it represents reachability, i.e. such that it is
40 // transitive. That the graph be transitive is thus not an invariant of this
41 // class, but it is required for the name of the class and its methods to make
42 // sense.
43 class HloReachabilityMap {
44  public:
45   // Sets up a graph with no edges and where the nodes correspond to the given
46   // instructions.
47   explicit HloReachabilityMap(
48       absl::Span<const HloInstruction* const> instructions);
49 
50   // Computes and returns the reachability between HLO instructions in the
51   // computation. The returned HloReachabilityMap is constructed such that
52   // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a
53   // directed path (from producer to consumer) from 'a' to 'b'. Both data
54   // dependencies (operands) and control dependencies are considered for
55   // reachability. Trivially an instruction is reachable from itself.
56   static std::unique_ptr<HloReachabilityMap> Build(
57       const HloComputation* computation);
58 
59   // Similar to the above Build operation except that it tries to identify
60   // paths between instructions that do not contain control instructions
61   // and multiple operands, i.e., b is_reachable a == true iff
62   // b = f(f(f(f(f(a), constant), constant), constant).
63   // Further, the only ops allowed in a path are basic math operations such
64   // as add, sub, mul, div.
65   static std::unique_ptr<HloReachabilityMap> BuildWithRestrictions(
66       const HloComputation* computation,
67       absl::FunctionRef<void(const HloInstruction*,
68                              std::vector<HloInstruction*>*)>
69           add_dependencies);
70 
71   // Set the reachability set of 'instruction' to the union of the reachability
72   // sets of 'inputs'. Upon return, IsReachable(x, instruction) where
73   // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true
74   // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from
75   // itself. Returns whether the reachability set of 'instruction' changed.
76   //
77   // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
78   // vector in the internal graph of this HloReachabilityMap for the given
79   // instruction and does not transitively update any other part of the
80   // adjacency matrix.
81   bool SetReachabilityToUnion(absl::Span<const HloInstruction* const> inputs,
82                               const HloInstruction* instruction);
83 
84   // As above, but faster because it does not check if the reachability changed.
85   void FastSetReachabilityToUnion(
86       absl::Span<const HloInstruction* const> inputs,
87       const HloInstruction* instruction);
88 
89   // An opaque index that clients can use to make repeated operations for the
90   // same instruction faster, by calling GetIndex once for the instruction,
91   // and then calling the variants of other interfaces that take Index arguments
92   // rather than HloInstruction* arguments.
93   struct Index {
94    private:
95     friend class HloReachabilityMap;
96 
97     // Index assigned for a particular instruction.  The value is used to index
98     // into the vector of BitVectors and the BitVectors themselves.
99     int v;
100   };
GetIndex(const HloInstruction * instruction)101   Index GetIndex(const HloInstruction* instruction) const {
102     Index i;
103     i.v = FindOrDie(indices_, GetKey(instruction));
104     return i;
105   }
106 
107   // Sets entry so that IsReachable(a, b) will return true
108   //
109   // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
110   // matrix in the internal graph of this HloReachabilityMap to have an edge
111   // from a to b and does not transitively update any other part of the
112   // adjacency matrix.
SetReachable(const HloInstruction * a,const HloInstruction * b)113   void SetReachable(const HloInstruction* a, const HloInstruction* b) {
114     SetReachable(GetIndex(a), GetIndex(b));
115   }
116   void SetReachable(Index a, Index b);
117 
118   // Updates the given reachability map after the immediate predecessor set
119   // (operands and control predecessors) of 'instruction' has changed.
120   void UpdateReachabilityThroughInstruction(const HloInstruction* instruction);
121 
122   // Returns true if "b" is reachable from "a"
123   //
124   // Note that this function only correctly answers queries about reachability
125   // if the set of edges that have been provided to this class are transitive.
IsReachable(const HloInstruction * a,const HloInstruction * b)126   bool IsReachable(const HloInstruction* a, const HloInstruction* b) const {
127     return IsReachable(GetIndex(a), GetIndex(b));
128   }
IsReachable(Index a,Index b)129   bool IsReachable(Index a, Index b) const { return GetBitVector(b).Get(a.v); }
130 
131   // Returns true if "b" is reachable from "a" or "a" is reachable from "b"
132   //
133   // Note that this function only correctly answers queries about reachability
134   // if the set of edges that have been provided to this class are transitive.
IsConnected(const HloInstruction * a,const HloInstruction * b)135   bool IsConnected(const HloInstruction* a, const HloInstruction* b) const {
136     return IsConnected(GetIndex(a), GetIndex(b));
137   }
IsConnected(Index a,Index b)138   bool IsConnected(Index a, Index b) const {
139     return IsReachable(a, b) || IsReachable(b, a);
140   }
141 
142   // Checks if an instruction is in the Reachability map.
IsPresent(const HloInstruction * a)143   bool IsPresent(const HloInstruction* a) const {
144     return indices_.contains(GetKey(a));
145   }
146 
147   // Replace the instruction "original" with "replacement" in the reachability
148   // map.
149   void Replace(const HloInstruction* original,
150                const HloInstruction* replacement);
151 
152  private:
153   // A bit-vector implementation specialized for this use case which provides a
154   // fast bitwise OR operation not available in tensorflow::gtl::BitMap.
155   class BitVector {
156    public:
157     BitVector() = default;
BitVector(size_t size)158     BitVector(size_t size)
159         : size_(size), vector_((size + kBits - 1) / kBits, 0) {}
160 
161     // Return the bit at the given index.
Get(size_t index)162     bool Get(size_t index) const {
163       DCHECK(index >= 0 && index < size_);
164       return vector_[index / kBits] & (1ull << (index % kBits));
165     }
166 
167     // Set the bit at the given index.
Set(size_t index)168     void Set(size_t index) {
169       DCHECK(index >= 0 && index < size_);
170       vector_[index / kBits] |= 1ull << (index % kBits);
171     }
172 
173     // Set this bitvector to the Logical OR of this bitvector and 'other'.
OrWith(const BitVector & other)174     void OrWith(const BitVector& other) {
175       for (size_t i = 0; i < vector_.size(); ++i) {
176         vector_[i] |= other.vector_[i];
177       }
178     }
179 
180     // Set the bitvector to all zeros.
SetToZero()181     void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); }
182 
183     bool operator==(const BitVector& other) const {
184       return vector_ == other.vector_;
185     }
186     bool operator!=(const BitVector& other) const {
187       return vector_ != other.vector_;
188     }
189 
190    private:
191     using Word = uint64;
192     static constexpr size_t kBits = 64;
193 
194     // Number of bits in the bitvector.
195     size_t size_;
196 
197     std::vector<Word> vector_;
198   };
199 
200   // Return the bitvector storing the reachability-to of the given instruction.
GetBitVector(const HloInstruction * instruction)201   const BitVector& GetBitVector(const HloInstruction* instruction) const {
202     return GetBitVector(GetIndex(instruction));
203   }
GetBitVector(const HloInstruction * instruction)204   BitVector& GetBitVector(const HloInstruction* instruction) {
205     return GetBitVector(GetIndex(instruction));
206   }
207 
GetBitVector(Index index)208   const BitVector& GetBitVector(Index index) const {
209     return bit_vectors_[index.v];
210   }
GetBitVector(Index index)211   BitVector& GetBitVector(Index index) { return bit_vectors_[index.v]; }
212 
213   // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion.
214   void SetReachabilityToUnionHelper(
215       absl::Span<const HloInstruction* const> inputs,
216       const HloInstruction* instruction, BitVector* bit_vector);
217 
GetKey(const HloInstruction * instruction)218   uint64 GetKey(const HloInstruction* instruction) const {
219     uint64 unique_id = absl::bit_cast<uint32>(instruction->unique_id());
220     uint64 module_id =
221         absl::bit_cast<uint32>(instruction->parent()->parent()->unique_id());
222     return (module_id << 32) | unique_id;
223   }
224   // Return the index of the given instruction.
GetIndexInternal(const HloInstruction * instruction)225   int GetIndexInternal(const HloInstruction* instruction) const {
226     return FindOrDie(indices_, GetKey(instruction));
227   }
228 
229   // The number of instructions in the reachability map.
230   const size_t size_;
231 
232   // Dense assignment from HloInstruction::unique_id to number. These numbers
233   // index into the bit_vectors_ vector and into the bits within a BitVector.
234   absl::flat_hash_map<uint64, int> indices_;
235 
236   // Bitvectors holding the reachability to each instruction. The bit vector for
237   // instruction X includes ones for each instruction which X is reachable from.
238   std::vector<BitVector> bit_vectors_;
239 
240   // A temporary used by SetReachabilityToUnion to avoid an allocation with each
241   // call to the method.
242   BitVector tmp_bit_vector_;
243 };
244 
245 }  // namespace xla
246 
247 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
248