1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
18 
19 #include <cstdio>
20 #include <list>
21 #include <vector>
22 
23 #include "absl/base/casts.h"
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/map_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 
36 // A class for representing reachability between HloInstructions.
37 //
38 // It has an adjacency matrix and it is up to the user of the class to set the
39 // adjacency matrix such that it represents reachability, i.e. such that it is
40 // transitive. That the graph be transitive is thus not an invariant of this
41 // class, but it is required for the name of the class and its methods to make
42 // sense.
43 class HloReachabilityMap {
44  public:
45   // Sets up a graph with no edges and where the nodes correspond to the given
46   // instructions.
47   explicit HloReachabilityMap(
48       absl::Span<const HloInstruction* const> instructions);
49 
50   // Computes and returns the reachability between HLO instructions in the
51   // computation. The returned HloReachabilityMap is constructed such that
52   // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a
53   // directed path (from producer to consumer) from 'a' to 'b'. Both data
54   // dependencies (operands) and control dependencies are considered for
55   // reachability. Trivially an instruction is reachable from itself.
56   static std::unique_ptr<HloReachabilityMap> Build(
57       const HloComputation* computation);
58 
59   // Set the reachability set of 'instruction' to the union of the reachability
60   // sets of 'inputs'. Upon return, IsReachable(x, instruction) where
61   // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true
62   // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from
63   // itself. Returns whether the reachability set of 'instruction' changed.
64   //
65   // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
66   // vector in the internal graph of this HloReachabilityMap for the given
67   // instruction and does not transitively update any other part of the
68   // adjacency matrix.
69   bool SetReachabilityToUnion(absl::Span<const HloInstruction* const> inputs,
70                               const HloInstruction* instruction);
71 
72   // As above, but faster because it does not check if the reachability changed.
73   void FastSetReachabilityToUnion(
74       absl::Span<const HloInstruction* const> inputs,
75       const HloInstruction* instruction);
76 
77   // Sets entry so that IsReachable(a, b) will return true
78   //
79   // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
80   // matrix in the internal graph of this HloReachabilityMap to have an edge
81   // from a to b and does not transitively update any other part of the
82   // adjacency matrix.
83   void SetReachable(const HloInstruction* a, const HloInstruction* b);
84 
85   // Updates the given reachability map after the immediate predecessor set
86   // (operands and control predecessors) of 'instruction' has changed.
87   void UpdateReachabilityThroughInstruction(const HloInstruction* instruction);
88 
89   // Returns true if "b" is reachable from "a"
90   //
91   // Note that this function only correctly answers queries about reachability
92   // if the set of edges that have been provided to this class are transitive.
93   bool IsReachable(const HloInstruction* a, const HloInstruction* b) const;
94 
95   // Returns true if "b" is reachable from "a" or "a" is reachable from "b"
96   //
97   // Note that this function only correctly answers queries about reachability
98   // if the set of edges that have been provided to this class are transitive.
99   bool IsConnected(const HloInstruction* a, const HloInstruction* b) const;
100 
101   // Checks if an instruction is in the Reachability map.
IsPresent(const HloInstruction * a)102   bool IsPresent(const HloInstruction* a) const {
103     return indices_.contains(GetKey(a));
104   }
105 
106  private:
107   // A bit-vector implementation specialized for this use case which provides a
108   // fast bitwise OR operation not available in tensorflow::gtl::BitMap.
109   class BitVector {
110    public:
111     BitVector() = default;
BitVector(size_t size)112     BitVector(size_t size)
113         : size_(size), vector_((size + kBits - 1) / kBits, 0) {}
114 
115     // Return the bit at the given index.
Get(size_t index)116     bool Get(size_t index) const {
117       DCHECK(index >= 0 && index < size_);
118       return vector_[index / kBits] & (1ull << (index % kBits));
119     }
120 
121     // Set the bit at the given index.
Set(size_t index)122     void Set(size_t index) {
123       DCHECK(index >= 0 && index < size_);
124       vector_[index / kBits] |= 1ull << (index % kBits);
125     }
126 
127     // Set this bitvector to the Logical OR of this bitvector and 'other'.
OrWith(const BitVector & other)128     void OrWith(const BitVector& other) {
129       for (size_t i = 0; i < vector_.size(); ++i) {
130         vector_[i] |= other.vector_[i];
131       }
132     }
133 
134     // Set the bitvector to all zeros.
SetToZero()135     void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); }
136 
137     bool operator==(const BitVector& other) const {
138       return vector_ == other.vector_;
139     }
140     bool operator!=(const BitVector& other) const {
141       return vector_ != other.vector_;
142     }
143 
144    private:
145     using Word = uint64;
146     static const size_t kBits = 64;
147 
148     // Number of bits in the bitvector.
149     size_t size_;
150 
151     std::vector<Word> vector_;
152   };
153 
154   // Return the bitvector storing the reachability-to of the given instruction.
GetBitVector(const HloInstruction * instruction)155   const BitVector& GetBitVector(const HloInstruction* instruction) const {
156     return bit_vectors_[GetIndex(instruction)];
157   }
GetBitVector(const HloInstruction * instruction)158   BitVector& GetBitVector(const HloInstruction* instruction) {
159     return bit_vectors_[GetIndex(instruction)];
160   }
161 
162   // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion.
163   void SetReachabilityToUnionHelper(
164       absl::Span<const HloInstruction* const> inputs,
165       const HloInstruction* instruction, BitVector* bit_vector);
166 
GetKey(const HloInstruction * instruction)167   uint64 GetKey(const HloInstruction* instruction) const {
168     uint64 unique_id = absl::bit_cast<uint32>(instruction->unique_id());
169     uint64 module_id =
170         absl::bit_cast<uint32>(instruction->parent()->parent()->unique_id());
171     return (module_id << 32) | unique_id;
172   }
173   // Return the index of the given instruction. The value is used to index into
174   // the vector of BitVectors and the BitVectors themselves.
GetIndex(const HloInstruction * instruction)175   int GetIndex(const HloInstruction* instruction) const {
176     return FindOrDie(indices_, GetKey(instruction));
177   }
178 
179   // The number of instructions in the reachability map.
180   const size_t size_;
181 
182   // Dense assignment from HloInstruction::unique_id to number. These numbers
183   // index into the bit_vectors_ vector and into the bits within a BitVector.
184   absl::flat_hash_map<uint64, int> indices_;
185 
186   // Bitvectors holding the reachability to each instruction. The bit vector for
187   // instruction X includes ones for each instruction which X is reachable from.
188   std::vector<BitVector> bit_vectors_;
189 
190   // A temporary used by SetReachabilityToUnion to avoid an allocation with each
191   // call to the method.
192   BitVector tmp_bit_vector_;
193 };
194 
195 }  // namespace xla
196 
197 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
198