/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_

#include <type_traits>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/util/ptr_util.h"

namespace xla {

// IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a
// gather from another array. It does this by mapping HLO instructions to
// instances of IndexedArrayAnalysis::Array, which can be inspected to discover
// whether said HLO is equivalent to a gather.
class IndexedArrayAnalysis {
 public:
  // IndexedArrayAnalysis maps each HLO instruction to an instance of an
  // Array. Array is really just a sum type of the classes that inherit from
  // it. The meaning of each of the subtypes is documented on the subtype
  // declaration.
  //
  // Array instances are immutable once created.
  class Array {
   public:
    enum Kind {
      kUnknown,
      kConstant,
      kReshaped,
      kScalarIndexedConstant,
      kScalarIndexed
    };

    virtual Kind kind() const = 0;
    virtual const Shape& shape() const = 0;

    // Does a checked downcast from `Array` to `T` which must be one of its
    // subtypes.
    template <typename T>
    T* as() {
      static_assert((std::is_base_of<Array, T>::value),
                    "target type not derived from source type");
      // We skip the CHECK and hence the dynamic_cast if RTTI is disabled.
#if !defined(__GNUC__) || defined(__GXX_RTTI)
      CHECK_NE(dynamic_cast<T*>(this), nullptr);
#endif  // !defined(__GNUC__) || defined(__GXX_RTTI)

      return static_cast<T*>(this);
    }

    virtual ~Array() = default;

    Array& operator=(const Array& other) = delete;
  };
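  // A minimal usage sketch for the checked downcast above (hypothetical;
  // `array` is assumed to be an `Array*` produced by this analysis):
  //
  //   if (array->kind() == Array::kConstant) {
  //     const Literal* literal = array->as<ConstantArray>()->literal();
  //     ...
  //   }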
  // Represents an HLO instruction that was not analyzable by this
  // IndexedArrayAnalysis. Instances of UnknownArray just wrap an existing
  // HloInstruction.
  class UnknownArray : public Array {
   public:
    Kind kind() const override { return kUnknown; }
    const Shape& shape() const override { return instruction().shape(); }
    const HloInstruction& instruction() const { return instruction_; }

   private:
    explicit UnknownArray(const HloInstruction* instr)
        : instruction_(*instr) {}

    const HloInstruction& instruction_;

    friend class IndexedArrayAnalysis;
  };

  // Represents a constant value. This constant value may be present in the
  // HLO module being analyzed, or it could have been created on the fly by
  // the analysis.
  class ConstantArray : public Array {
   public:
    Kind kind() const override { return kConstant; }
    const Shape& shape() const override { return literal()->shape(); }
    const Literal* literal() const { return literal_; }

   private:
    explicit ConstantArray(const Literal* literal) : literal_(literal) {}
    const Literal* literal_;

    friend class IndexedArrayAnalysis;
  };

  // Represents an Array that is a reshape of another Array.
  class ReshapedArray : public Array {
   public:
    Kind kind() const override { return kReshaped; }

    // The array to reshape.
    Array* operand() const { return operand_; }

    // The output shape.
    const Shape& shape() const override { return shape_; }

   private:
    explicit ReshapedArray(Array* operand, Shape shape)
        : operand_(operand), shape_(shape) {}

    Array* operand_;
    const Shape shape_;

    friend class IndexedArrayAnalysis;
  };

  // ---------------------------------------------------------------------------
  // Indexed Array Overview
  // ---------------------------------------------------------------------------
  //
  // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this
  // analysis. ScalarIndexedConstantArray is just a specialization of
  // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this
  // overview.
  //
  // A ScalarIndexedArray represents an array that can be computed by indexing
  // into a "source" array using an "indices" tensor. A simple example is a
  // gather operation gathering 12 rows out of a [100,100] matrix -- such an
  // operation will be represented by an instance of a ScalarIndexedArray with
  // the [100,100] matrix as the "source" array and the [12]-shaped indices
  // array as the "indices" tensor. The ScalarIndexedArray operation itself
  // will be of shape [12,100] (assuming we were gathering with axis=0).
  //
  // Gather operations are not the only operations that map to
  // ScalarIndexedArray instances (if that were true there would be little
  // point in having a separate analysis). We can often infer
  // ScalarIndexedArrays for other operations too. For instance, consider:
  //
  //   %source = f32[100,100] constant
  //   %indices = s32[12] ...
  //   %gather = f32[12,100] ... gather from %source using %indices at axis 0
  //   %dot = dot(%gather, other_constant) [canonical contracting dims]
  //
  // The dot operation itself is also a ScalarIndexedArray with source =
  // dot(constant, other_constant) and indices = %indices. A reshape of
  // %gather to [12,5,20] is likewise a ScalarIndexedArray with source = an
  // appropriately reshaped constant and indices = %indices.
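  //
  // In list-comprehension pseudocode (the same notation used for
  // FoldGatherOfGather below), the %dot inference above rests on the
  // identity:
  //
  //   %dot = [dot(%source, other_constant)[i] for i in %indices]
  //
  // i.e. gathering rows of %source and then contracting against
  // other_constant produces the same values as contracting first and then
  // gathering rows of the result, so the dot can be represented as a gather
  // from a new constant.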
  // Represents the result of a gather operation. This gather operation may
  // explicitly be present in the HLO module being analyzed, or it could have
  // been created on the fly by the analysis.
  //
  // An instance of ScalarIndexedArray represents an array whose I'th element
  // can be mapped to the J'th element of the `source` array (where I and J
  // are multidimensional indices) in this way:
  //
  //   I' = remove components at positions `output_dims` from I
  //   G' = remove components not at positions `output_dims` from I
  //   T  = indices[G']
  //   J  = I' with T inserted at position `source_dim`
  //
  // For example, if source is of shape [11,13,17,19], indices is of shape
  // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of
  // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the
  // input index [B,D,indices[A,C],E].
  class ScalarIndexedArray : public Array {
   public:
    Kind kind() const override { return kScalarIndexed; }
    const Shape& shape() const override { return shape_; }

    Array* source() const { return source_; }
    Array* indices() const { return indices_; }

    // `source_dim` is the dimension in the source array that is being indexed
    // over using indices from the `indices` array. See the class
    // documentation and the overview for more details.
    int64 source_dim() const { return source_dim_; }

    // `output_dims` are the dimensions in the output array that are being
    // used to compute an index into the `indices` array. See the class
    // documentation and the overview for more details.
    absl::Span<const int64> output_dims() const { return output_dims_; }

   private:
    explicit ScalarIndexedArray(Array* source, Array* indices,
                                int64 source_dim,
                                std::vector<int64> output_dims, Shape shape)
        : source_(source),
          indices_(indices),
          source_dim_(source_dim),
          output_dims_(std::move(output_dims)),
          shape_(std::move(shape)) {}

    Array* source_;
    Array* indices_;
    int64 source_dim_;
    std::vector<int64> output_dims_;
    Shape shape_;

    friend class IndexedArrayAnalysis;
  };

  // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to
  // have a ConstantArray instance as the source. This is an ergonomic
  // concession -- in theory it is possible to just keep ScalarIndexedArray
  // and check source()->kind().
  class ScalarIndexedConstantArray : public ScalarIndexedArray {
   public:
    Kind kind() const override { return kScalarIndexedConstant; }

    const Literal& literal() const {
      return *source()->as<ConstantArray>()->literal();
    }

   private:
    explicit ScalarIndexedConstantArray(Array* source, Array* indices,
                                        int64 source_dim,
                                        std::vector<int64> output_dims,
                                        Shape shape)
        : ScalarIndexedArray(source, indices, source_dim,
                             std::move(output_dims), std::move(shape)) {
      CHECK(dynamic_cast<ConstantArray*>(source));
    }

    friend class IndexedArrayAnalysis;
  };

  // Returns an Array instance for `instr`. The IndexedArrayAnalysis instance
  // keeps ownership of the returned Array instance.
  //
  // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO
  // instructions to IndexedArrayAnalysis::Array instances. This entire cache
  // becomes stale and may cause the analysis to return incorrect results if
  // any transitive operand (stopping at the containing computation) is
  // modified for any HLO instruction on which GetArrayFor has been invoked.
  //
  // NB! By inspecting the implementation, you may be able to infer a stronger
  // caching guarantee than what is mentioned above. Nevertheless, what is
  // stated above is the contract.
  StatusOr<Array*> GetArrayFor(const HloInstruction* instr);
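  // A minimal client-side sketch of the API above (hypothetical; `analysis`
  // and `instr` are assumed to be supplied by the caller):
  //
  //   TF_ASSIGN_OR_RETURN(Array* root, analysis.GetArrayFor(instr));
  //   if (root->kind() == Array::kScalarIndexedConstant) {
  //     auto* indexed = root->as<ScalarIndexedConstantArray>();
  //     // `instr` is equivalent to a gather from `indexed->literal()` over
  //     // dimension `indexed->source_dim()` using `indexed->indices()`.
  //   }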
  // Pretty-prints the expression rooted at `root`.
  string ToString(Array* root, bool print_constants = false);

 private:
  // Helper function that ensures that every HLO instruction that is
  // transitively used by `root` has an entry in `cache_`.
  Status TraverseAndPopulateCache(const HloInstruction* root);

  // Creates an Array instance for `instr` under the assumption that all
  // operands of `instr` are already present in `cache_`.
  StatusOr<Array*> ComputeArrayFor(const HloInstruction* instr);

  StatusOr<Array*> ComputeArrayForConstant(const Literal& literal);

  StatusOr<Array*> ComputeArrayForGather(
      const Shape& shape, const GatherDimensionNumbers& dim_numbers,
      absl::Span<const int64> slice_sizes, Array* source, Array* indices);

  StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
      const Shape& shape, const DotDimensionNumbers& dim_numbers,
      const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
      ConstantArray* rhs);

  StatusOr<Array*> ComputeArrayForDotWithIndexedRhs(
      const Shape& shape, const DotDimensionNumbers& dim_numbers,
      const PrecisionConfig& precision_config, ConstantArray* lhs,
      ScalarIndexedConstantArray* rhs);

  StatusOr<Array*> ComputeArrayForDot(const Shape& shape,
                                      const DotDimensionNumbers& dim_numbers,
                                      const PrecisionConfig& precision_config,
                                      Array* lhs, Array* rhs);

  // This tries to fold a ScalarIndexedArray which has another
  // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has
  // a ScalarIndexedArray as indices. If `source` happened to be a
  // ScalarIndexedConstantArray this can result in an expression that is more
  // canonical.
  //
  // As an example, consider a gather operation, G0, gathering 7 elements from
  // an array "Arr" of shape [100] resulting in an array of shape [7], and a
  // second gather operation, G1, which gathers 3 elements out of the result
  // of G0 resulting in an array of shape [3]. Let the indices used by G0 be
  // I0 (of shape [7]) and the indices used by G1 be I1 (of shape [3]). We can
  // instead rewrite G1 to gather directly from "Arr" with the three indices
  // from I0 as per I1. In other words, we can rewrite:
  //
  //   G0 = [Arr[i] for i in I0]
  //   G1 = [G0[i] for i in I1]
  //
  // into
  //
  //   I2 = [I0[i] for i in I1]
  //   G1 = [Arr[i] for i in I2]
  StatusOr<ScalarIndexedArray*> FoldGatherOfGather(
      ScalarIndexedArray* source, Array* indices, int64 source_dim,
      absl::Span<const int64> output_dims, Shape shape);
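  // For instance (a hypothetical trace of the rewrite above, with made-up
  // index values), if I0 = [5,1,4,4,2,0,3] and I1 = [2,0,6] then
  //
  //   I2 = [I0[2], I0[0], I0[6]] = [4,5,3]
  //   G1 = [Arr[4], Arr[5], Arr[3]]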
  // Reshapes a scalar-indexed node to remove the degenerate dimensions in its
  // output. The result is always a scalar-indexed node.
  StatusOr<ScalarIndexedArray*> ReshapeToRemoveDegenerateDims(
      ScalarIndexedArray* operand);

  // Reshapes a scalar-indexed node such that the result has the degenerate
  // dimensions `degenerate_dims`. The result is always a scalar-indexed node.
  StatusOr<ScalarIndexedArray*> ReshapeToAddDegenerateDims(
      ScalarIndexedArray* operand, absl::Span<const int64> degenerate_dims);

  StatusOr<ScalarIndexedArray*> FoldReshapeOfGather(
      const Shape& shape, ScalarIndexedConstantArray* operand);
  StatusOr<ScalarIndexedArray*> FoldReshapeOfGatherNoDegenerateDims(
      const Shape& shape, ScalarIndexedConstantArray* scalar_indexed);
  StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand);

  StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
                                                      Array* lhs, Array* rhs);
  StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
                                                     Array* operand);

  template <typename T, typename... Args>
  T* Construct(Args&&... args) {
    T* new_tensor = new T(std::forward<Args>(args)...);
    owned_tensors_.push_back(std::unique_ptr<T>(new_tensor));
    return new_tensor;
  }

  ScalarIndexedArray* ConstructScalarIndexedArray(
      Array* source, Array* indices, int64 source_dim,
      std::vector<int64> output_dims, Shape shape) {
    if (source->kind() == Array::kConstant) {
      return Construct<ScalarIndexedConstantArray>(source, indices, source_dim,
                                                   std::move(output_dims),
                                                   std::move(shape));
    } else {
      return Construct<ScalarIndexedArray>(source, indices, source_dim,
                                           std::move(output_dims),
                                           std::move(shape));
    }
  }

  Literal* TakeOwnership(Literal literal) {
    owned_literals_.push_back(std::move(literal));
    return &owned_literals_.back();
  }

  StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
    TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
    owned_literals_.push_back(std::move(literal));
    return &owned_literals_.back();
  }

  std::vector<std::unique_ptr<Array>> owned_tensors_;
  std::vector<Literal> owned_literals_;
  absl::flat_hash_map<const HloInstruction*, Array*> cache_;
};

// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to
// unconditionally add to the regular HLO pass pipeline.
class IndexedArrayAnalysisPrinterPass : public HloModulePass {
 public:
  absl::string_view name() const override;
  StatusOr<bool> Run(HloModule* module) override;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_