/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_

#include <memory>
#include <type_traits>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/util/ptr_util.h"

namespace xla {

// IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a
// gather from another array.  It does this by mapping HLO instructions to
// instances of IndexedArrayAnalysis::Array, which can be inspected to discover
// whether said HLO is equivalent to a gather.
class IndexedArrayAnalysis {
 public:
  // IndexedArrayAnalysis maps each HLO instruction to an instance of Array.
  // Array is really just a sum type of the classes that inherit from it.  The
  // meaning of each of the subtypes is documented on the subtype declaration.
  //
  // Array instances are immutable once created.
  class Array {
   public:
    enum Kind {
      kUnknown,
      kConstant,
      kReshaped,
      kScalarIndexedConstant,
      kScalarIndexed
    };

    virtual Kind kind() const = 0;
    virtual const Shape& shape() const = 0;

    // Does a checked downcast from `Array` to `T` which must be one of its
    // subtypes.
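    //
    // Usage sketch (hypothetical caller code; `array` is an `Array*` produced
    // by this analysis):
    //
    //   if (array->kind() == Array::kConstant) {
    //     const Literal* literal = array->as<ConstantArray>()->literal();
    //     ...
    //   }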
    template <typename T>
    T* as() {
      static_assert((std::is_base_of<Array, T>::value),
                    "target type not derived from source type");
      // We skip the CHECK and hence the dynamic_cast if RTTI is disabled.
#if !defined(__GNUC__) || defined(__GXX_RTTI)
      CHECK_NE(dynamic_cast<T*>(this), nullptr);
#endif  // !defined(__GNUC__) || defined(__GXX_RTTI)

      return static_cast<T*>(this);
    }

    virtual ~Array() = default;

    Array& operator=(const Array& other) = delete;
  };

  // Represents an HLO instruction that was not analyzable by this
  // IndexedArrayAnalysis.  Instances of UnknownArray just wrap an existing
  // HloInstruction.
  class UnknownArray : public Array {
   public:
    Kind kind() const override { return kUnknown; }
    const Shape& shape() const override { return instruction().shape(); }
    const HloInstruction& instruction() const { return instruction_; }

   private:
    explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {}

    const HloInstruction& instruction_;

    friend class IndexedArrayAnalysis;
  };

  // Represents a constant value.  This constant value may be present in the HLO
  // module being analyzed, or it could have been created on the fly by the
  // analysis.
  class ConstantArray : public Array {
   public:
    Kind kind() const override { return kConstant; }
    const Shape& shape() const override { return literal()->shape(); }
    const Literal* literal() const { return literal_; }

   private:
    explicit ConstantArray(const Literal* literal) : literal_(literal) {}
    const Literal* literal_;

    friend class IndexedArrayAnalysis;
  };

  // Represents an Array that is a reshape of another Array.
  class ReshapedArray : public Array {
   public:
    Kind kind() const override { return kReshaped; }

    // The array to reshape.
    Array* operand() const { return operand_; }

    // The output shape.
    const Shape& shape() const override { return shape_; }

   private:
    explicit ReshapedArray(Array* operand, Shape shape)
        : operand_(operand), shape_(shape) {}

    Array* operand_;
    const Shape shape_;

    friend class IndexedArrayAnalysis;
  };

  // ---------------------------------------------------------------------------
  // Indexed Array Overview
  // ---------------------------------------------------------------------------
  //
  // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this
  // analysis.  ScalarIndexedConstantArray is just a specialization of
  // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this
  // overview.
  //
  // A ScalarIndexedArray represents an array that can be computed by indexing
  // into a "source" array using an "indices" tensor.  A simple example is a
  // gather operation gathering 12 rows out of a [100,100] matrix -- such an
  // operation will be represented by an instance of a ScalarIndexedArray with
  // the [100,100] matrix as the "source" array and the [12]-shaped indices
  // array as the "indices" tensor.  The ScalarIndexedArray operation itself
  // will be of shape [12,100] (assuming we were gathering with axis=0).
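  //
  // In terms of the fields on ScalarIndexedArray below (a sketch matching
  // this example), that gather corresponds to: source = the [100,100] matrix,
  // indices = the [12]-shaped indices array, source_dim = 0, output_dims =
  // [0] and shape = [12,100].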
  //
  // Gather operations are not the only operations that map to
  // ScalarIndexedArray instances (if that were true there would be little
  // point in having a separate analysis).  We can often infer
  // ScalarIndexedArrays for other operations too.  For instance, consider:
  //
  //   %source = f32[100,100] constant
  //   %indices = s32[12] ...
  //   %gather = f32[12,100] ... gather from %source using %indices at axis 0
  //   %dot = dot(%gather, other_constant) [canonical contracting dims]
  //
  // The dot operation itself is also a ScalarIndexedArray with source =
  // dot(constant, other_constant) and indices = %indices.  Similarly, a
  // reshape of %gather to [12,5,20] is a ScalarIndexedArray with source = an
  // appropriately reshaped constant and indices = %indices.
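  //
  // As a sketch (assuming other_constant has shape [100,N]), the %dot above
  // can therefore be re-expressed as a gather from a new constant source,
  // since gathering whole rows commutes with a contraction along columns:
  //
  //   %new_source = dot(%source, other_constant)   f32[100,N]
  //   %dot = f32[12,N] ... gather from %new_source using %indices at axis 0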

  // Represents the result of a gather operation.  This gather operation may
  // explicitly be present in the HLO module being analyzed, or it could have
  // been created on the fly by the analysis.
  //
  // An instance of ScalarIndexedArray represents an array whose I'th element
  // can be mapped to the J'th element of the `source` array (where I and J are
  // multidimensional indices) in this way:
  //
  //   I' = remove components at positions `output_dims` from I
  //   G' = remove components not at positions `output_dims` from I
  //   T  = indices[G']
  //   J  = I' with T inserted at position `source_dim`
  //
  // For example, if source is of shape [11,13,17,19], indices is of shape
  // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of
  // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the
  // input index [B,D,indices[A,C],E].
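  //
  // Stepping through the recipe above for that example, with output index
  // I = [A,B,C,D,E] and output_dims = [0,2]:
  //
  //   I' = [B,D,E]        (components of I not at positions 0 and 2)
  //   G' = [A,C]          (components of I at positions 0 and 2)
  //   T  = indices[A,C]
  //   J  = [B,D,T,E]      (T inserted at source_dim = 2)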
  class ScalarIndexedArray : public Array {
   public:
    Kind kind() const override { return kScalarIndexed; }
    const Shape& shape() const override { return shape_; }

    Array* source() const { return source_; }
    Array* indices() const { return indices_; }

    // `source_dim` is the dimension in the source array that is being indexed
    // over using indices from the `indices` array.  See the class documentation
    // and the overview for more details.
    int64 source_dim() const { return source_dim_; }

    // `output_dims` are the dimensions in the output array that are being used
    // to compute an index into the `indices` array.  See the class
    // documentation and the overview for more details.
    absl::Span<const int64> output_dims() const { return output_dims_; }

   private:
    explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim,
                                std::vector<int64> output_dims, Shape shape)
        : source_(source),
          indices_(indices),
          source_dim_(source_dim),
          output_dims_(std::move(output_dims)),
          shape_(std::move(shape)) {}

    Array* source_;
    Array* indices_;
    int64 source_dim_;
    std::vector<int64> output_dims_;
    Shape shape_;

    friend class IndexedArrayAnalysis;
  };

  // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to
  // have a ConstantArray instance as the source.  This is an ergonomic
  // concession -- in theory it is possible to just keep ScalarIndexedArray and
  // check source()->kind().
  class ScalarIndexedConstantArray : public ScalarIndexedArray {
   public:
    Kind kind() const override { return kScalarIndexedConstant; }

    const Literal& literal() const {
      return *source()->as<ConstantArray>()->literal();
    }

   private:
    explicit ScalarIndexedConstantArray(Array* source, Array* indices,
                                        int64 source_dim,
                                        std::vector<int64> output_dims,
                                        Shape shape)
        : ScalarIndexedArray(source, indices, source_dim,
                             std::move(output_dims), std::move(shape)) {
      CHECK(dynamic_cast<ConstantArray*>(source));
    }

    friend class IndexedArrayAnalysis;
  };

  // Returns an Array instance for `instr`.  The IndexedArrayAnalysis instance
  // keeps ownership of the returned Array instance.
  //
  // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO
  // instructions to IndexedArrayAnalysis::Array instances.  This entire cache
  // becomes stale, and may cause the analysis to return incorrect results, if
  // any transitive operand (up to the boundary of the containing computation)
  // of any HLO instruction on which GetArrayFor has been invoked is modified.
  //
  // NB!  By inspecting the implementation, you may be able to infer a stronger
  // caching guarantee than what is mentioned above.  Nevertheless, what is
  // stated above is the contract.
  StatusOr<Array*> GetArrayFor(const HloInstruction* instr);
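
  // Usage sketch (hypothetical caller code; `analysis` is an
  // IndexedArrayAnalysis and `instr` the HloInstruction* being examined;
  // names are left unqualified for brevity):
  //
  //   TF_ASSIGN_OR_RETURN(Array* array, analysis.GetArrayFor(instr));
  //   if (array->kind() == Array::kScalarIndexedConstant) {
  //     auto* scalar_indexed = array->as<ScalarIndexedConstantArray>();
  //     // `instr` is equivalent to a gather from a constant; inspect
  //     // scalar_indexed->literal() and scalar_indexed->indices().
  //   }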

  // Pretty-prints the expression rooted at `root`.
  string ToString(Array* root, bool print_constants = false);

 private:
  // Helper function that ensures that every HLO instruction that is
  // transitively used by `root` has an entry in `cache_`.
  Status TraverseAndPopulateCache(const HloInstruction* root);

  // Creates an Array instance for `instr` under the assumption that all
  // operands of `instr` are already present in `cache_`.
  StatusOr<Array*> ComputeArrayFor(const HloInstruction* instr);

  StatusOr<Array*> ComputeArrayForConstant(const Literal& literal);

  StatusOr<Array*> ComputeArrayForGather(
      const Shape& shape, const GatherDimensionNumbers& dim_numbers,
      absl::Span<const int64> slice_sizes, Array* source, Array* indices);

  StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
      const Shape& shape, const DotDimensionNumbers& dim_numbers,
      const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
      ConstantArray* rhs);

  StatusOr<Array*> ComputeArrayForDotWithIndexedRhs(
      const Shape& shape, const DotDimensionNumbers& dim_numbers,
      const PrecisionConfig& precision_config, ConstantArray* lhs,
      ScalarIndexedConstantArray* rhs);

  StatusOr<Array*> ComputeArrayForDot(const Shape& shape,
                                      const DotDimensionNumbers& dim_numbers,
                                      const PrecisionConfig& precision_config,
                                      Array* lhs, Array* rhs);

  // This tries to fold a ScalarIndexedArray which has another
  // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
  // ScalarIndexedArray as indices.  If `source` happened to be a
  // ScalarIndexedConstantArray this can result in an expression that is more
  // canonical.
  //
  // As an example, consider a gather operation, G0, gathering 7 elements from
  // an array "Arr" of shape [100] resulting in an array of shape [7], and a
  // second gather operation, G1, which gathers 3 elements out of the result of
  // G0 resulting in an array of shape [3].  Let the indices used by G0 be I0
  // (of shape [7]) and the indices used by G1 be I1 (of shape [3]).  We can
  // instead rewrite G1 to gather directly from "Arr" with the three indices
  // from I0 selected by I1.  In other words, we can rewrite:
  //
  //    G0 = [Arr[i] for i in I0]
  //    G1 = [G0[i]  for i in I1]
  //
  // into
  //
  //    I2 = [I0[i]  for i in I1]
  //    G1 = [Arr[i] for i in I2]
  StatusOr<ScalarIndexedArray*> FoldGatherOfGather(
      ScalarIndexedArray* source, Array* indices, int64 source_dim,
      absl::Span<const int64> output_dims, Shape shape);

  // Reshapes a scalar-indexed node to remove the degenerate dimensions in its
  // output.  The result is always a scalar-indexed node.  For example, an
  // operand of shape [1,7,1,3] is reshaped to shape [7,3].
  StatusOr<ScalarIndexedArray*> ReshapeToRemoveDegenerateDims(
      ScalarIndexedArray* operand);

  // Reshapes a scalar-indexed node such that the result has the degenerate
  // dimensions `degenerate_dims`.  The result is always a scalar-indexed node.
  // For example, with degenerate_dims = [0,2] an operand of shape [7,3] is
  // reshaped to shape [1,7,1,3].
  StatusOr<ScalarIndexedArray*> ReshapeToAddDegenerateDims(
      ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims);

  StatusOr<ScalarIndexedArray*> FoldReshapeOfGather(
      const Shape& shape, ScalarIndexedConstantArray* operand);
  StatusOr<ScalarIndexedArray*> FoldReshapeOfGatherNoDegenerateDims(
      const Shape& shape, ScalarIndexedConstantArray* scalar_indexed);
  StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand);

  StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
                                                      Array* lhs, Array* rhs);
  StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
                                                     Array* operand);

  template <typename T, typename... Args>
  T* Construct(Args&&... args) {
    T* new_tensor = new T(std::forward<Args>(args)...);
    owned_tensors_.push_back(std::unique_ptr<T>(new_tensor));
    return new_tensor;
  }

  ScalarIndexedArray* ConstructScalarIndexedArray(
      Array* source, Array* indices, int64 source_dim,
      std::vector<int64> output_dims, Shape shape) {
    if (source->kind() == Array::kConstant) {
      return Construct<ScalarIndexedConstantArray>(source, indices, source_dim,
                                                   std::move(output_dims),
                                                   std::move(shape));
    } else {
      return Construct<ScalarIndexedArray>(source, indices, source_dim,
                                           std::move(output_dims),
                                           std::move(shape));
    }
  }

  // Takes ownership of `literal` and returns a pointer to it.  The literal is
  // heap-allocated so that the returned pointer stays valid even as
  // `owned_literals_` grows (a plain std::vector<Literal> would invalidate
  // previously returned pointers when it reallocates).
  Literal* TakeOwnership(Literal literal) {
    owned_literals_.push_back(
        std::unique_ptr<Literal>(new Literal(std::move(literal))));
    return owned_literals_.back().get();
  }

  StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
    TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
    return TakeOwnership(std::move(literal));
  }

  std::vector<std::unique_ptr<Array>> owned_tensors_;
  std::vector<std::unique_ptr<Literal>> owned_literals_;
  absl::flat_hash_map<const HloInstruction*, Array*> cache_;
};

// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to
// unconditionally add to the regular HLO pass pipeline.
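//
// Usage sketch (hypothetical pipeline-construction code):
//
//   HloPassPipeline pipeline("pipeline");
//   pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();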
class IndexedArrayAnalysisPrinterPass : public HloModulePass {
 public:
  absl::string_view name() const override;
  StatusOr<bool> Run(HloModule* module) override;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_