1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // HLO instructions are in DAG form and represent the computations that the user 17 // has built up via the XLA service interface. They are ultimately lowered 18 // in a platform-aware way by traversing the HLO DAG and emitting a lowered 19 // form; e.g. see DfsHloVisitor. 20 21 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ 22 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ 23 24 #include <functional> 25 #include <iosfwd> 26 #include <list> 27 #include <memory> 28 #include <set> 29 #include <string> 30 #include <tuple> 31 #include <unordered_map> 32 #include <unordered_set> 33 #include <vector> 34 35 #include "tensorflow/compiler/xla/iterator_util.h" 36 #include "tensorflow/compiler/xla/literal_util.h" 37 #include "tensorflow/compiler/xla/map_util.h" 38 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" 39 #include "tensorflow/compiler/xla/service/hlo.pb.h" 40 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 41 #include "tensorflow/compiler/xla/service/hlo_sharding.h" 42 #include "tensorflow/compiler/xla/service/name_uniquer.h" 43 #include "tensorflow/compiler/xla/types.h" 44 #include "tensorflow/compiler/xla/xla_data.pb.h" 45 #include "tensorflow/core/lib/core/status.h" 46 #include "tensorflow/core/lib/core/stringpiece.h" 47 #include 
"tensorflow/core/lib/gtl/array_slice.h" 48 #include "tensorflow/core/lib/gtl/flatmap.h" 49 #include "tensorflow/core/lib/gtl/inlined_vector.h" 50 #include "tensorflow/core/lib/gtl/iterator_range.h" 51 #include "tensorflow/core/platform/logging.h" 52 #include "tensorflow/core/platform/macros.h" 53 #include "tensorflow/core/platform/types.h" 54 55 namespace xla { 56 57 class HloComputation; 58 class HloModule; 59 60 // A bunch of switches that control how the hlo text should be printed. 61 class HloPrintOptions { 62 public: 63 // Constructs the default print options: don't print large constants, don't 64 // compact operands, no indentation. HloPrintOptions()65 HloPrintOptions() 66 : print_large_constants_(false), 67 print_subcomputation_references_(true), 68 print_metadata_(true), 69 compact_operands_(false), 70 print_operand_shape_(true), 71 print_program_shape_(true), 72 print_percent_(true), 73 indent_amount_(0) {} 74 ShortParsable()75 static HloPrintOptions ShortParsable() { 76 return HloPrintOptions() 77 .set_print_large_constants(true) 78 .set_print_subcomputation_references(true) 79 .set_print_metadata(false) 80 .set_print_operand_shape(false) 81 .set_print_program_shape(false) 82 .set_print_percent(false); 83 } 84 85 // If true, large constants will be printed out. set_print_large_constants(bool value)86 HloPrintOptions& set_print_large_constants(bool value) { 87 print_large_constants_ = value; 88 return *this; 89 } 90 91 // If true, the names of subcomputations (e.g. a fusion node's fused 92 // computation) won't be printed. This makes the resulting text not parsable. 93 // 94 // A CustomCall's call target is printed even if 95 // print_subcomputation_references is false, because the call target isn't an 96 // HloComputation. set_print_subcomputation_references(bool value)97 HloPrintOptions& set_print_subcomputation_references(bool value) { 98 print_subcomputation_references_ = value; 99 return *this; 100 } 101 102 // If true, metatdata will be printed. 
set_print_metadata(bool value)103 HloPrintOptions& set_print_metadata(bool value) { 104 print_metadata_ = value; 105 return *this; 106 } 107 108 // If true, operands' shapes will be printed. set_print_operand_shape(bool value)109 HloPrintOptions& set_print_operand_shape(bool value) { 110 print_operand_shape_ = value; 111 return *this; 112 } 113 114 // If true, program shape of hlo computations will be printed. set_print_program_shape(bool value)115 HloPrintOptions& set_print_program_shape(bool value) { 116 print_program_shape_ = value; 117 return *this; 118 } 119 120 // If true, names will be printed with prefix '%'. set_print_percent(bool value)121 HloPrintOptions& set_print_percent(bool value) { 122 print_percent_ = value; 123 return *this; 124 } 125 126 // If true, only a part of operands will be printed out, and their names will 127 // be omitted (note that in this case the text will not be parsable). set_compact_operands(bool value)128 HloPrintOptions& set_compact_operands(bool value) { 129 compact_operands_ = value; 130 return *this; 131 } 132 133 // The indent of the hlo text block. 
set_indent_amount(int value)134 HloPrintOptions& set_indent_amount(int value) { 135 indent_amount_ = value; 136 return *this; 137 } 138 print_large_constants()139 bool print_large_constants() const { return print_large_constants_; } print_subcomputation_references()140 bool print_subcomputation_references() const { 141 return print_subcomputation_references_; 142 } print_metadata()143 bool print_metadata() const { return print_metadata_; } compact_operands()144 bool compact_operands() const { return compact_operands_; } print_operand_shape()145 bool print_operand_shape() const { return print_operand_shape_; } print_program_shape()146 bool print_program_shape() const { return print_program_shape_; } print_percent()147 bool print_percent() const { return print_percent_; } indent_amount()148 int indent_amount() const { return indent_amount_; } 149 150 private: 151 bool print_large_constants_; 152 bool print_subcomputation_references_; 153 bool print_metadata_; 154 bool compact_operands_; 155 bool print_operand_shape_; 156 bool print_program_shape_; 157 bool print_percent_; 158 int indent_amount_; 159 }; 160 161 // HLO instructions are the IR used by the high-level compiler. 162 class HloInstruction { 163 public: 164 enum class FusionKind { 165 kLoop, // Fused into a loop. 166 kInput, // Op's input is fused into the op itself. 167 kOutput, // Op's output is fused into the op itself. 168 // REQUIRES: At least one operand buffer must be able 169 // to alias the output buffer. 170 kTransposeDot, // Fused into a dot with transposed operands. 171 kCustom, // Custom category for backend-specific fusions that 172 // do not match any of the more specific ones. 173 }; 174 175 ~HloInstruction(); 176 177 // Creates an instruction from the given proto. Arguments: 178 // 179 // module: the module which will contain the instruction. The newly created 180 // instruction is *not* added to the module or any computation, however. 181 // proto: the proto to convert from. 
//   instruction_map: a map from instruction name to HloInstruction*. This map
  //     must contain all operands of the newly constructed instruction.
  //   computation_map: a map from computation name to HloComputation*. This map
  //     must contain all computations which the newly constructed instruction
  //     calls.
  //   add_fused_computation: A function to call to add a fused
  //     computation. Used (clearly) when the instruction is a fusion
  //     instruction.
  //
  // The result is wrapped in a StatusOr so that a failure can be reported to
  // the caller instead of an instruction.
  static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
      HloModule* module, const HloInstructionProto& proto,
      const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
      const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
      const std::function<void(std::unique_ptr<HloComputation>)>&
          add_fused_computation);

  // Creates a parameter-retrieving instruction.
  static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
                                                         const Shape& shape,
                                                         const string& name);

  // Creates a literal constant instruction. Ownership of `literal` is passed
  // in by the caller.
  static std::unique_ptr<HloInstruction> CreateConstant(
      std::unique_ptr<Literal> literal);

  // Creates a get tuple element instruction.
  static std::unique_ptr<HloInstruction> CreateGetTupleElement(
      const Shape& shape, HloInstruction* operand, int64 index);

  // Creates a trace instruction that logs the input operand in the computation.
  static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
                                                     HloInstruction* operand);

  // Creates a random number generation instruction that fills a shape with
  // random numbers from a given distribution.
  static std::unique_ptr<HloInstruction> CreateRng(
      const Shape& shape, RandomDistribution distribution,
      tensorflow::gtl::ArraySlice<HloInstruction*> parameters);

  // Creates a unary instruction (one operand).
  // Precondition: opcode must be a legitimate unary operation.
static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
                                                     HloOpcode opcode,
                                                     HloInstruction* operand);

  // Creates a binary instruction (two operands, `lhs` and `rhs`).
  // Precondition: opcode must be a legitimate binary operation.
  static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
                                                      HloOpcode opcode,
                                                      HloInstruction* lhs,
                                                      HloInstruction* rhs);

  // Creates a ternary instruction (three operands, `lhs`, `rhs` and `ehs`).
  // Precondition: opcode must be a legitimate ternary operation.
  static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
                                                       HloOpcode opcode,
                                                       HloInstruction* lhs,
                                                       HloInstruction* rhs,
                                                       HloInstruction* ehs);

  // Creates a variadic instruction (variable number of operands).
  // Precondition: opcode must be a legitimate variadic operation.
  static std::unique_ptr<HloInstruction> CreateVariadic(
      const Shape& shape, HloOpcode opcode,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands);

  // Creates a map instruction, where the computation (given by the handle) is
  // applied element-wise to every element in operands (across the operands,
  // at a given index) with the same `static_operands`. `static_operands`
  // defaults to an empty list.
  static std::unique_ptr<HloInstruction> CreateMap(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      HloComputation* map_computation,
      tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});

  // Creates a convolution op, where rhs is the convolutional filter
  // and window describes how the filter is applied to lhs.
  static std::unique_ptr<HloInstruction> CreateConvolve(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const Window& window,
      const ConvolutionDimensionNumbers& dimension_numbers);

  // Creates an FFT op, of the type indicated by fft_type.
static std::unique_ptr<HloInstruction> CreateFft(
      const Shape& shape, HloInstruction* operand, FftType fft_type,
      tensorflow::gtl::ArraySlice<int64> fft_length);

  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
  // dimensions specified in 'dimension_numbers'.
  static std::unique_ptr<HloInstruction> CreateDot(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const DotDimensionNumbers& dimension_numbers);

  // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
  // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
  // and the RHS must be of rank 2.
  static std::unique_ptr<HloInstruction> CreateCanonicalDot(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);

  // Creates a reduce-precision op, where operand is the data to reduce in
  // precision, and exponent_bits and mantissa_bits describe the precision to
  // reduce it to.
  static std::unique_ptr<HloInstruction> CreateReducePrecision(
      const Shape& shape, HloInstruction* operand, const int exponent_bits,
      const int mantissa_bits);

  // Creates a cross replica sum op.
  static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
      const Shape& shape,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands);

  // Creates a conversion instruction, where operand is the data to convert and
  // shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast conversion instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcastConvert(
      const Shape& shape, HloInstruction* operand);

  // Creates an infeed instruction, which reads data of the given shape from the
  // Infeed interface of the device.
static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
                                                      const string& config);

  // Creates an outfeed instruction, which outputs data.
  static std::unique_ptr<HloInstruction> CreateOutfeed(
      const Shape& shape, HloInstruction* operand,
      tensorflow::StringPiece outfeed_config);

  // Creates an asynchronous send instruction with the given channel id, which
  // initiates sending the operand data to a unique receive instruction in
  // another computation that has the same channel id.
  static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
                                                    int64 channel_id);

  // Blocks until data transfer for the Send instruction (operand) is complete.
  // The operand must be kSend.
  static std::unique_ptr<HloInstruction> CreateSendDone(
      HloInstruction* operand);

  // Creates an asynchronous receive instruction with the given channel id,
  // which allocates resources to receive data of the given shape from a unique
  // send instruction in another computation that has the same channel id.
  static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
                                                    int64 channel_id);

  // Blocks until data transfer for the Recv instruction (operand) is complete
  // and returns the receive buffer. The operand must be kRecv.
  static std::unique_ptr<HloInstruction> CreateRecvDone(
      HloInstruction* operand);

  // Creates a slice instruction, where the operand is sliced by the given
  // start/limit indices and strides.
  static std::unique_ptr<HloInstruction> CreateSlice(
      const Shape& shape, HloInstruction* operand,
      tensorflow::gtl::ArraySlice<int64> start_indices,
      tensorflow::gtl::ArraySlice<int64> limit_indices,
      tensorflow::gtl::ArraySlice<int64> strides);

  // Creates a slice instruction, where the first operand is sliced by
  // start indices specified in the second operand, and by size specified in
  // 'slice_sizes'.
static std::unique_ptr<HloInstruction> CreateDynamicSlice(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      tensorflow::gtl::ArraySlice<int64> slice_sizes);

  // Creates a dynamic update slice instruction, which updates a slice
  // of 'operand' with 'update' and 'start_indices'.
  static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      HloInstruction* start_indices);

  // Creates a concatenate instruction, where the operands are concatenated on
  // the provided dimension.
  static std::unique_ptr<HloInstruction> CreateConcatenate(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      int64 dimension);

  // Creates a reduce instruction, where the computation (given by the handle)
  // is applied successively to every element in operand. That is, if f is the
  // function to apply (which either takes 2 [accumulator, value] or 3
  // [accumulator, index, value] arguments) and init is a reduction operator
  // specified initial value (for example, 0 for addition), then this operation
  // will compute:
  //   f(... f(f(init, [index0], value0), [index1], value1) ...)
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // Creates a reduce-window instruction, where the computation (given
  // by the handle) is applied window-wise at each valid window
  // position in the operand.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      const Window& window, HloComputation* reduce_computation);

  // Creates a batch-norm-training instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, float epsilon, int64 feature_index);

  // Creates a batch-norm-inference instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormInference(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
      float epsilon, int64 feature_index);

  // Creates a batch-norm-grad instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* mean, HloInstruction* variance,
      HloInstruction* grad_output, float epsilon, int64 feature_index);

  // Creates a scatter computation that scatters the `source` array to the
  // selected indices of each window.
  static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
      const Shape& shape, HloInstruction* operand, HloComputation* select,
      const Window& window, HloInstruction* source, HloInstruction* init_value,
      HloComputation* scatter);

  // Creates a broadcast instruction.
  static std::unique_ptr<HloInstruction> CreateBroadcast(
      const Shape& shape, HloInstruction* operand,
      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);

  // Creates a sequence of instructions that performs an explicit broadcast of
  // the operand to the target shape.
  //
  // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
  // returned as a unique_ptr for API consistency with other factory methods in
  // this interface.
  //
  // TODO(b/72173833) Ideally HloComputations would always be present, and so
  // the adder being passed by the caller would not be necessary.
static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
      const Shape& output_shape, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          adder);

  // Creates a pad instruction, where the operand is padded on the edges and
  // between the elements with the given padding value.
  static std::unique_ptr<HloInstruction> CreatePad(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* padding_value, const PaddingConfig& padding_config);

  // Creates a reshape instruction, where the operand is flattened row-major
  // order and then reshaped to the given result shape.
  static std::unique_ptr<HloInstruction> CreateReshape(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a transpose instruction which permutes the operand dimensions.
  static std::unique_ptr<HloInstruction> CreateTranspose(
      const Shape& shape, HloInstruction* operand,
      tensorflow::gtl::ArraySlice<int64> dimensions);

  // Creates a while instruction, given a condition computation, a body
  // computation, and the initial value for the input of the computations. For
  // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
  // corresponds to the C code below.
//   int32 i = 1; int32 result = while(i < 1000) { i = i * 2 }
  static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
                                                     HloComputation* condition,
                                                     HloComputation* body,
                                                     HloInstruction* init);

  // Creates a conditional instruction from a predicate and a
  // (computation, argument) pair for each of the true and false cases.
  // NOTE(review): presumably `true_computation` is applied to
  // `true_computation_arg` when `pred` holds and `false_computation` to
  // `false_computation_arg` otherwise -- confirm against the implementation.
  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* pred,
      HloInstruction* true_computation_arg, HloComputation* true_computation,
      HloInstruction* false_computation_arg, HloComputation* false_computation);

  // Creates a gather instruction from `operand` and `gather_indices`, using
  // the given dimension numbers and window bounds.
  // NOTE(review): the exact gather semantics are not visible in this header --
  // confirm against the implementation before relying on them.
  static std::unique_ptr<HloInstruction> CreateGather(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* gather_indices,
      const GatherDimensionNumbers& gather_dim_numbers,
      tensorflow::gtl::ArraySlice<int64> window_bounds);

  // Creates a fusion instruction. A fusion instruction contains one or more
  // fused instructions forming an expression with a single root
  // "fused_root". Additional instructions can be added to the fusion
  // instruction with the method FuseInstruction.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);

  // Overload of CreateFusion that takes the operands and a fusion computation
  // rather than a single fused root.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      HloComputation* fusion_computation);

  // Creates a call instruction that applies the given computation on the given
  // operands. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCall(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      HloComputation* computation);

  // Creates a custom call instruction that applies the given custom call target
  // to the given operands. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      tensorflow::StringPiece custom_call_target);

  // Creates a HostCompute instruction, which records host-side control and
  // data dependencies for use in instruction scheduling.
  static std::unique_ptr<HloInstruction> CreateHostCompute(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);

  // Creates a tuple instruction with the given elements. This is a convenience
  // wrapper around CreateVariadic.
  static std::unique_ptr<HloInstruction> CreateTuple(
      tensorflow::gtl::ArraySlice<HloInstruction*> elements);

  // Creates a reverse instruction, which reverses the order of the elements
  // in the specified dimensions.
  static std::unique_ptr<HloInstruction> CreateReverse(
      const Shape& shape, HloInstruction* operand,
      tensorflow::gtl::ArraySlice<int64> dimensions);

  // Creates an instance of GatherDimensionNumbers from the three given
  // dimension lists.
  static GatherDimensionNumbers MakeGatherDimNumbers(
      tensorflow::gtl::ArraySlice<int64> output_window_dims,
      tensorflow::gtl::ArraySlice<int64> elided_window_dims,
      tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims);

  // Returns the opcode for this instruction.
  HloOpcode opcode() const { return opcode_; }

  // Returns true if this instruction has a side effect. An instruction has a
  // side effect if it uses certain opcodes or calls a computation with a side
  // effect.
  bool HasSideEffect() const;

  // Returns the result shape of this instruction.
  const Shape& shape() const;

  // Returns the (mutable) result shape of this instruction. The pointer refers
  // to this instruction's own shape_ member.
  Shape* mutable_shape() { return &shape_; }

  // Returns the ith operand to this instruction.
522 const HloInstruction* operand(int64 i) const; 523 524 // Returns the ith operand to this instruction. 525 HloInstruction* mutable_operand(int64 i); 526 527 // Returns the number of operands to this instruction. operand_count()528 int64 operand_count() const { return operands_.size(); } 529 530 // Returns the vector of operands of this instruction. 531 using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>; operands()532 const InstructionVector& operands() const { return operands_; } 533 534 // Returns the index of 'target' in the operands sequence. 535 // Precondition: target must be an operand (or a fatal error will occur). 536 int64 operand_index(const HloInstruction* target) const; 537 538 // Returns the number of users of this instruction. user_count()539 int64 user_count() const { return users_.size(); } 540 541 // Returns the users of this instruction. users()542 const std::vector<HloInstruction*>& users() const { return users_; } 543 544 // Returns true if this instruction is a user of 'instruction'. IsUserOf(const HloInstruction * instruction)545 bool IsUserOf(const HloInstruction* instruction) const { 546 return ContainsKey(instruction->user_set_, this); 547 } 548 549 // Adds a control dependency from this instruction to the given 550 // instruction. This instruction becomes a control predecessor of 551 // 'instruction', and 'instruction' becomes a control successor of this 552 // instruction. Returns an error status if either of the given instructions 553 // does not belong to the same computation. 554 // 555 // This is used to enforce an additional ordering requirement that is not 556 // captured by normal data dependencies, such as ordering among Send or Recv 557 // operations to avoid deadlock. 558 Status AddControlDependencyTo(HloInstruction* instruction); 559 560 // Removes a previously added control dependency from this instruction to 561 // 'instruction'. 
562 Status RemoveControlDependencyTo(HloInstruction* instruction); 563 564 // Returns the set of control predecessors (successors) of this 565 // instruction. Control predecessors (successors) must execute before (after) 566 // the current instruction. control_predecessors()567 const std::vector<HloInstruction*>& control_predecessors() const { 568 return control_predecessors_; 569 } control_successors()570 const std::vector<HloInstruction*>& control_successors() const { 571 return control_successors_; 572 } 573 574 // Returns true if "other" performs the same computation as this instruction. 575 bool Identical( 576 const HloInstruction& other, 577 const std::function<bool(const HloInstruction*, const HloInstruction*)>& 578 eq_operands = std::equal_to<const HloInstruction*>(), 579 const std::function<bool(const HloComputation*, const HloComputation*)>& 580 eq_computations = std::equal_to<const HloComputation*>(), 581 bool layout_sensitive = true) const { 582 // An instruction is always identical to itself. 583 if (this == &other) { 584 return true; 585 } 586 587 // Identical instruction must have the same opcode, shape, and identical 588 // operands. 589 if (opcode() != other.opcode()) { 590 return false; 591 } 592 using EqShapeFuncType = bool (*)(const Shape&, const Shape&); 593 EqShapeFuncType eq_shapes = 594 layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible; 595 if (!eq_shapes(shape(), other.shape())) { 596 return false; 597 } 598 if (operands().size() != other.operands().size()) { 599 return false; 600 } 601 602 // Use an explicit loop rather than ContainerEquals, because copying around 603 // std::functions may be too expensive in some cases. 604 for (size_t i = 0; i < operands().size(); ++i) { 605 if (!eq_operands(operand(i), other.operand(i))) { 606 return false; 607 } 608 } 609 610 return IdenticalSlowPath(other, eq_computations, eq_shapes); 611 } 612 613 // Returns whether the instruction has a constant operand. 
bool HasConstantOperand() const;

  // Returns whether this instruction does a rank-2 transposition.
  bool IsRank2Transpose() const;

  // Replaces the use of this instruction in "user" with "new_producer". Note
  // that there might be multiple uses of this instruction in "user"; all will
  // be replaced.
  Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);

  // Replaces the specified operand with new_operand.
  Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);

  // Replaces all uses of this instruction with the new producer. If
  // new_producer is a user of this instruction then new_producer remains a use
  // of this instruction to avoid introducing cycles into the graph.
  //
  // If this instruction is the root of its computation, sets the computation's
  // root to new_producer.
  Status ReplaceAllUsesWith(HloInstruction* new_producer);

  // Detaches an instruction from its operands. That is, remove the instruction
  // from each operand's user set. This should only be called prior to
  // deallocating the instruction.
  void DetachFromOperands();

  // Performs a postorder DFS visit using this node as the root. If
  // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
  // complete. If ignore_control_predecessors is true, instructions only
  // reachable via control dependencies will not be visited, and the postorder
  // will not take control dependencies into account. It is as if the control
  // dependencies didn't exist in the graph at all.
646 template <typename HloInstructionPtr> 647 Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor, 648 bool call_finish_visit = true, 649 bool ignore_control_predecessors = false); 650 Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true, 651 bool ignore_control_predecessors = false) const { 652 return const_cast<HloInstruction*>(this)->Accept( 653 visitor, call_finish_visit, ignore_control_predecessors); 654 } 655 656 // Same as Accept() above, but the order of operand and control predecessor 657 // visitation is determined by the given operand order; if compare(A, B) == 658 // true, A is visited before B. 659 using CompareFunction = 660 std::function<bool(const HloInstruction*, const HloInstruction*)>; 661 Status AcceptWithOperandOrder(DfsHloVisitor* visitor, 662 const CompareFunction& operand_order, 663 bool call_finish_visit = true); 664 665 // Performs a postorder DFS visit using this node as the root. Calls the given 666 // visitor function at each instruction. 667 Status Accept(const std::function<Status(HloInstruction*)>& visitor_func); 668 Status Accept( 669 const std::function<Status(const HloInstruction*)>& visitor_func) const; 670 671 // Visits all instructions rooted at this instruction using the given visitor 672 // in the given order. 'order' must contain at least the set of instructions 673 // rooted at this node (ie, those accessible from a DFS traversal from this 674 // instruction). Instructions contained in 'order' which are not in the set of 675 // instructions rooted at this node are ignored. 'order' must also be a valid 676 // topological sort of these instructions (defs appear before uses) though 677 // need not be a DFS post-order. 678 Status AcceptOrdered(DfsHloVisitor* visitor, 679 const std::vector<const HloInstruction*>& order); 680 681 // Visit this instruction and only this instruction with the given visitor. 
682 template <typename HloInstructionPtr> 683 Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor); 684 685 // Returns the literal associated with this instruction. 686 // 687 // Note: only constant and parameter opcodes have an associated literal. 688 const Literal& literal() const; 689 690 // Returns the parameter number associated with this instruction. 691 // 692 // Note: only parameter opcodes have an associated parameter number. parameter_number()693 int64 parameter_number() const { 694 CHECK_EQ(HloOpcode::kParameter, opcode_); 695 return parameter_number_; 696 } 697 698 // Returns the dimension sizes or numbers associated with this instruction. 699 // 700 // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, 701 // and reverse. 702 const std::vector<int64>& dimensions() const; 703 int64 dimensions(int64 index) const; 704 705 // Accessor for the dimension in which a concatenate HLO should occur. 706 // Precondition: opcode() == HloOpcode::kConcatenate 707 int64 concatenate_dimension() const; 708 709 // Returns the tuple index associated with this instruction. 710 // 711 // Precondition: opcode() == HloOpcode::kGetTupleElement 712 int64 tuple_index() const; 713 714 // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. 715 // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the 716 // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. 717 std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() 718 const; 719 LatestNonGteAncestorAndIndex()720 std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() { 721 auto rv = 722 const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex(); 723 return {const_cast<HloInstruction*>(rv.first), rv.second}; 724 } 725 726 // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction. 
727 const HloInstruction* LatestNonGteAncestor() const; 728 LatestNonGteAncestor()729 HloInstruction* LatestNonGteAncestor() { 730 return const_cast<HloInstruction*>( 731 const_cast<const HloInstruction*>(this)->LatestNonGteAncestor()); 732 } 733 734 // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc. 735 // The setter should only be called by HloModule or HloComputation methods. 736 // 737 // Precondition: The instruction has a valid to_apply_ field. 738 HloComputation* to_apply() const; 739 void set_to_apply(HloComputation* to_apply); 740 741 // Returns the custom_call_target for CustomCall. 742 // Precondition: opcode() == HloOpcode::kCustomCall 743 const string& custom_call_target() const; 744 745 // Returns the config for the Outfeed instruction. 746 // Precondition: opcode() == HloOpcode::kOutfeed 747 const string& outfeed_config() const; 748 749 // Returns the shape for the Outfeed instruction. 750 // Precondition: opcode() == HloOpcode::kOutfeed 751 const Shape& outfeed_shape() const; 752 753 // Gets/sets the while_condition or while_body HloComputation for While. The 754 // setters should only be called by HloModule or HloComputation methods. 755 // 756 // Precondition: The instruction is a While instruction. 757 HloComputation* while_condition() const; 758 HloComputation* while_body() const; 759 void set_while_condition(HloComputation* while_condition); 760 void set_while_body(HloComputation* while_body); 761 762 // Gets/sets the select or scatter HloComputation for SelectAndScatter. The 763 // setters should only be called by HloModule or HloComputation methods. 764 // 765 // Precondition: opcode() == HloOpcode::kSelectAndScatter. 766 HloComputation* select() const; 767 HloComputation* scatter() const; 768 void set_select(HloComputation* select); 769 void set_scatter(HloComputation* scatter); 770 771 // Gets/sets the true and false HloComputation for Conditional. 
The setters 772 // should only be called by HloModule or HloComputation methods. 773 // 774 // Precondition: The instruction is a Conditional instruction. 775 HloComputation* true_computation() const; 776 HloComputation* false_computation() const; 777 void set_true_computation(HloComputation* true_computation); 778 void set_false_computation(HloComputation* false_computation); 779 780 // Returns a string for the signature of this instruction if considered as a 781 // function, e.g. the signature of an F32 add is (F32, F32) -> F32. 782 string SignatureString() const; 783 784 // Returns a debugging string that represents this instruction. 785 // 786 // (We express the default options using an overload rather than a default 787 // param because gdb ignores default params, but does resolve overloads.) 788 // 789 // TODO(b/73348663): Make ToString() adaptive to the size of the string by 790 // default, backing off on providing full information for very large strings, 791 // or provide a different name for a ToString-like function that does that. ToString()792 string ToString() const { return ToString(HloPrintOptions()); } 793 string ToString(const HloPrintOptions& options) const; 794 795 // Components of the ToString() representation: 796 797 // Returns a string representation of the operand list. 798 string OperandsToString(const HloPrintOptions& options) const; 799 800 // Returns string representation of op-specific attributes. 801 std::vector<string> ExtraAttributesToString( 802 const HloPrintOptions& options) const; 803 804 // As ToString, but returns a shorter string. 805 string ToShortString() const; 806 807 // Returns a serialized representation of this instruction. 808 HloInstructionProto ToProto() const; 809 810 // Returns a category for the HLO. This could be something like "convolution" 811 // or "elementwise". 812 string ToCategory() const; 813 814 // Returns a logging instruction, if the output of this instruction is logged. 
815 // 816 // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace 817 HloInstruction* tracing() const; 818 void set_tracing(HloInstruction* trace_instruction); 819 820 // Returns the channel id associated with the instruction. The id is 821 // shared between each Send/Recv pair and is globally unique to identify each 822 // channel. 823 // 824 // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv channel_id()825 int64 channel_id() const { return channel_id_; } 826 827 // Returns feature_index field associated with the instruction. The index 828 // represents the index of the feature dimension. 829 // 830 // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, 831 // or kBatchNormGrad. feature_index()832 int64 feature_index() const { return feature_index_; } 833 834 // Returns a epsilon value associated with the instruction. The is a small 835 // number added to the variance to avoid divide-by-zero error. 836 // 837 // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, 838 // or kBatchNormGrad. epsilon()839 float epsilon() const { return epsilon_; } 840 841 // Returns the infeed configuration string. The infeed configuration includes 842 // any metadata needed for the backend compiler (e.g., infeed buffer address) 843 // and is target-dependent. infeed_config()844 string infeed_config() const { return infeed_config_; } set_infeed_config(const string & config)845 void set_infeed_config(const string& config) { infeed_config_ = config; } 846 847 // Returns a tag to be used in tracing. 848 // 849 // Precondition: opcode() == HloOpcode::kTrace 850 string TracingTag() const; 851 852 // Returns whether the instruction is a constant. 853 bool IsConstant() const; 854 855 // Returns true if this instruction is fused, ie contained within a fusion 856 // instruction. 857 bool IsFused() const; 858 859 // Returns the computation for this fused instruction. 
860 // 861 // Precondition: opcode() == HloOpcode::kFusion 862 HloComputation* fused_instructions_computation() const; 863 864 // Returns true if this instruction can be legally fused into a fusion 865 // instruction. 866 bool IsFusable() const; 867 868 // Returns the root instruction of the fused expression contained within this 869 // fusion instruction. 870 // 871 // Precondition: opcode() == HloOpcode::kFusion 872 HloInstruction* fused_expression_root() const; 873 874 // Returns the list of fused instructions inside this fusion instruction. The 875 // returned type is a range of HloInstruction*s. 876 // 877 // Precondition: opcode() == HloOpcode::kFusion 878 const tensorflow::gtl::iterator_range<UnwrappingIterator< 879 std::list<std::unique_ptr<HloInstruction>>::const_iterator>> 880 fused_instructions() const; 881 882 const tensorflow::gtl::iterator_range< 883 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>> 884 fused_instructions(); 885 886 // Gets the number of instructions inside this fusion instruction. 887 // 888 // Precondition: opcode() == HloOpcode::kFusion 889 int64 fused_instruction_count() const; 890 891 // Returns the fused parameter instruction in this fusion instruction 892 // corresponding to the given parameter number. 893 // 894 // Precondition: opcode() == HloOpcode::kFusion 895 HloInstruction* fused_parameter(int64 parameter_number) const; 896 897 // Returns the vector of fused parameters inside this fusion instruction. 898 // 899 // Precondition: opcode() == HloOpcode::kFusion 900 const std::vector<HloInstruction*>& fused_parameters() const; 901 902 // Returns true if this instruction is a fusion instruction that generates 903 // multiple outputs. 
IsMultiOutputFusion()904 const bool IsMultiOutputFusion() const { 905 return opcode() == HloOpcode::kFusion && 906 fused_expression_root()->opcode() == HloOpcode::kTuple; 907 } 908 fusion_kind()909 FusionKind fusion_kind() const { 910 CHECK_EQ(HloOpcode::kFusion, opcode_); 911 return fusion_kind_; 912 } 913 set_fusion_kind(FusionKind kind)914 void set_fusion_kind(FusionKind kind) { 915 CHECK_EQ(HloOpcode::kFusion, opcode_); 916 fusion_kind_ = kind; 917 } 918 919 // Returns the sharding applied to this operator. 920 // REQUIRES: has_sharding() is true. sharding()921 const HloSharding& sharding() const { 922 CHECK(has_sharding()); 923 return *sharding_; 924 } 925 // Returns the sharding applied to this operator, or default_ if none exists. sharding_or_default(const HloSharding & default_)926 const HloSharding& sharding_or_default(const HloSharding& default_) const { 927 return sharding_ ? *sharding_ : default_; 928 } 929 // Sets the sharding of this operator. Should only be called by HloModule or 930 // HloComputation methods. set_sharding(const HloSharding & sharding)931 void set_sharding(const HloSharding& sharding) { 932 sharding_ = MakeUnique<HloSharding>(sharding); 933 } 934 // Remove any sharding from this operator. clear_sharding()935 void clear_sharding() { sharding_ = nullptr; } 936 // Return true if this operator has a sharding assigned. has_sharding()937 bool has_sharding() const { return sharding_ != nullptr; } 938 939 // Adds a new operand the fusion instruction. 940 HloInstruction* AddFusionOperand(HloInstruction* new_operand); 941 942 // Merges the fused instructions from 'instruction_to_merge' into the 943 // fused instruction set of 'this', updating operands as necessary. 944 // 945 // Precondition: opcode() == HloOpcode::kFusion 946 // Predondition: 'instruction_to_merge' must be an operand of 'this'. 
947 void MergeFusionInstruction(HloInstruction* instruction_to_merge); 948 949 // Merges the fused instructions from instruction_to_merge into the fused 950 // instruction set of 'this' and generates multioutput fusion instructions. 951 // All the users of instruction_to_merge will be redirected to 'this' 952 // instruction. instruction_to_merge will be removed from its parent 953 // computation. 954 // 955 // Precondition: opcode() == HloOpcode::kFusion 956 void MergeFusionInstructionIntoMultiOutput( 957 HloInstruction* instruction_to_merge); 958 959 // Fuses the given instruction in this fusion instruction. instruction_to_fuse 960 // is cloned and the clone is placed in the fusion 961 // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather 962 // than moved to cleanly handle the case where the instruction has a use 963 // outside the fusion instruction. Moving such an instruction into a fusion 964 // instruction would violate the single-result invariant of HLO instructions 965 // and significantly complicate code generation. 966 // 967 // Precondition: this->opcode() == HloOpcode::kFusion FuseInstruction(HloInstruction * instruction_to_fuse)968 HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { 969 return FuseInstructionInternal(instruction_to_fuse); 970 } 971 972 // Fuses the given instruction in this fusion instruction and generate 973 // multioutput fusion instruction. A clone of the instruction_to_fuse will 974 // be part of the output of fusion instructions. The users of 975 // instruction_to_fuse will be redirected to this fusion instructions. 976 // instruction_to_fuse will be removed from its parent computation. 
977 // 978 // Precondition: this->opcode() == HloOpcode::kFusion FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)979 HloInstruction* FuseInstructionIntoMultiOutput( 980 HloInstruction* instruction_to_fuse) { 981 return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); 982 } 983 984 // Returns the start index in the given dimension for a slice node. 985 // 986 // Precondition: opcode() == HloOpcode::kSlice slice_starts(int64 dimension)987 int64 slice_starts(int64 dimension) const { 988 CHECK_EQ(HloOpcode::kSlice, opcode_); 989 return slice_starts_[dimension]; 990 } slice_starts()991 const std::vector<int64>& slice_starts() const { return slice_starts_; } 992 993 // Returns the (exclusive) limit index in the given dimension for a slice 994 // node. 995 // 996 // Precondition: opcode() == HloOpcode::kSlice slice_limits(int64 dimension)997 int64 slice_limits(int64 dimension) const { 998 CHECK_EQ(HloOpcode::kSlice, opcode_); 999 return slice_limits_[dimension]; 1000 } slice_limits()1001 const std::vector<int64>& slice_limits() const { 1002 CHECK_EQ(HloOpcode::kSlice, opcode_); 1003 return slice_limits_; 1004 } 1005 1006 // Returns the stride in the given dimension for a slice node. 1007 // 1008 // Precondition: opcode() == HloOpcode::kSlice slice_strides(int64 dimension)1009 int64 slice_strides(int64 dimension) const { 1010 CHECK_EQ(HloOpcode::kSlice, opcode_); 1011 return slice_strides_[dimension]; 1012 } slice_strides()1013 const std::vector<int64>& slice_strides() const { return slice_strides_; } 1014 1015 // Returns the flag that describes whether a slice must be lowered into an 1016 // offset into the original operand. IsInPlaceSlice()1017 bool IsInPlaceSlice() const { return is_in_place_slice_; } 1018 1019 // Sets and returns the flag that describes whether a slice must be lowered 1020 // into an offset into the original operand. 
SetIsInPlaceSlice(bool value)1021 bool SetIsInPlaceSlice(bool value) { 1022 is_in_place_slice_ = value; 1023 return value; 1024 } 1025 1026 // Returns the size of the slice in the given dimension for a dynamic 1027 // slice node. 1028 // 1029 // Precondition: opcode() == HloOpcode::kDynamicSlice slice_sizes(int64 dimension)1030 int64 slice_sizes(int64 dimension) const { 1031 CHECK_EQ(HloOpcode::kDynamicSlice, opcode_); 1032 return dynamic_slice_sizes_[dimension]; 1033 } dynamic_slice_sizes()1034 const std::vector<int64>& dynamic_slice_sizes() const { 1035 CHECK_EQ(HloOpcode::kDynamicSlice, opcode_); 1036 return dynamic_slice_sizes_; 1037 } 1038 1039 // Returns the number of exponent bits for a reduce-precision node. 1040 // 1041 // Precondition: opcode() == HloOpcode::kReducePrecision exponent_bits()1042 int32 exponent_bits() const { 1043 CHECK_EQ(HloOpcode::kReducePrecision, opcode_); 1044 return exponent_bits_; 1045 } 1046 1047 // Returns the number of mantissa bits for a reduce-precision node. 1048 // 1049 // Precondition: opcode() == HloOpcode::kReducePrecision mantissa_bits()1050 int32 mantissa_bits() const { 1051 CHECK_EQ(HloOpcode::kReducePrecision, opcode_); 1052 return mantissa_bits_; 1053 } 1054 1055 // Returns data on the window in a windowed operation such as 1056 // convolution. window()1057 const Window& window() const { 1058 CHECK(window_ != nullptr); 1059 return *window_; 1060 } 1061 1062 // Sets the window data in a windowed operation such as convolution. set_window(const Window & window)1063 void set_window(const Window& window) { 1064 window_ = MakeUnique<Window>(window); 1065 } 1066 1067 // Returns the padding configuration for a pad node. 
1068 // 1069 // Precondition: opcode() == HloOpcode::kPad padding_config()1070 const PaddingConfig& padding_config() const { 1071 CHECK(padding_config_ != nullptr); 1072 return *padding_config_; 1073 } 1074 1075 // Returns data on the dimension numbers used for a convolution operation, 1076 // which may be a kConvolution instruction or a kCustomCall that implements a 1077 // convolution. convolution_dimension_numbers()1078 const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { 1079 CHECK(convolution_dimension_numbers_ != nullptr); 1080 return *convolution_dimension_numbers_; 1081 } 1082 1083 // Sets the convolution dimension numbers on this instruction. In general you 1084 // shouldn't need to call this; instead, specify the convolution dimension 1085 // numbers when you create the instruction. set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1086 void set_convolution_dimension_numbers( 1087 const ConvolutionDimensionNumbers& dnums) { 1088 convolution_dimension_numbers_ = 1089 MakeUnique<ConvolutionDimensionNumbers>(dnums); 1090 } 1091 fft_type()1092 FftType fft_type() const { 1093 CHECK_EQ(HloOpcode::kFft, opcode_); 1094 return fft_type_; 1095 } 1096 fft_length()1097 const std::vector<int64>& fft_length() const { 1098 CHECK_EQ(HloOpcode::kFft, opcode_); 1099 return fft_length_; 1100 } 1101 1102 // Returns the dump string of the convolution dimension numbers. 1103 string ConvolutionDimensionNumbersToString() const; 1104 1105 // Returns data on the dimension numbers used for a dot operation. dot_dimension_numbers()1106 const DotDimensionNumbers& dot_dimension_numbers() const { 1107 CHECK(dot_dimension_numbers_ != nullptr); 1108 return *dot_dimension_numbers_; 1109 } 1110 1111 // Returns the dump string of the dot dimension numbers. 
1112 string DotDimensionNumbersToString() const; 1113 gather_dimension_numbers()1114 const GatherDimensionNumbers& gather_dimension_numbers() const { 1115 CHECK(gather_dimension_numbers_ != nullptr); 1116 return *gather_dimension_numbers_; 1117 } 1118 gather_window_bounds()1119 tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const { 1120 CHECK_EQ(opcode(), HloOpcode::kGather); 1121 return gather_window_bounds_; 1122 } 1123 1124 // Returns the dump string of the gather dimension numbers. 1125 string GatherDimensionNumbersToString() const; 1126 1127 // Returns the random distribution for this rng node. 1128 // 1129 // Precondition: opcode() == HloOpcode::kRng 1130 RandomDistribution random_distribution() const; 1131 1132 // Clones the HLO instruction. The clone will have the same opcode, shape, and 1133 // operands. After creation the clone has no uses. "this" (the instruction 1134 // cloned from) is not changed. Suffix is the string to append to the name of 1135 // the instruction to form the name of the cloned instruction. 1136 // If the module pointer is not nullptr, it will be the module where 1137 // the cloned computations will be added to (in order to support deep 1138 // cloning). 1139 std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone", 1140 HloModule* module = nullptr) const; 1141 1142 // Clones the HLO instruction as above but with new shape and operands. 1143 // If the module pointer is not nullptr, it will be the module where 1144 // the cloned computations will be added to (in order to support deep 1145 // cloning). 1146 std::unique_ptr<HloInstruction> CloneWithNewOperands( 1147 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, 1148 HloModule* module = nullptr) const; 1149 1150 // Returns the computations this instruction directly calls (if any). 
called_computations()1151 const std::vector<HloComputation*>& called_computations() const { 1152 return called_computations_; 1153 } 1154 1155 // Replaces all called computations based on a map function. This is needed 1156 // when we clone hlo_computations and want to let the instructions to point 1157 // to the newly cloned nodes. ReplaceCalledComputations(std::function<HloComputation * (HloComputation *)> map_function)1158 void ReplaceCalledComputations( 1159 std::function<HloComputation*(HloComputation*)> map_function) { 1160 for (int64 i = 0; i < called_computations_.size(); ++i) { 1161 called_computations_[i] = map_function(called_computations_[i]); 1162 } 1163 } 1164 1165 // Clears out the called computations. 1166 // 1167 // This is, in particular, necessary when inlining function bodies into their 1168 // caller. If there were side-effecting operations in the called computations, 1169 // the call itself is considered side-effecting and thus cannot be removed. By 1170 // clearing out the computations, we reflect the fact that all side-effecting 1171 // properties have been reflected in the caller, and make the call HLO 1172 // removable. ClearCalledComputations()1173 void ClearCalledComputations() { called_computations_.clear(); } 1174 1175 // Returns true if this instruction performs an elementwise operation on 1176 // `operand_idx`-th operand. An instruction is elementwise on an operand iff, 1177 // after performing necessary implicit broadcast 1178 // (cs/IrArray::EmitArrayElementAddress), to compute the output at index 1179 // {i_0,i_1,...,i_n}, the only element required from the operand (if any) is 1180 // the element at {i_0,i_1,...,i_n}. 1181 // 1182 // Note on performance: when this instruction is kFusion, this method, in the 1183 // worst case, scans all fused instructions. We could speed this up by 1184 // caching. 
1185 bool IsElementwiseOnOperand(int64 operand_idx) const; 1186 1187 // Returns true if this instruction is elementwise on all its operands. 1188 bool IsElementwise() const; 1189 1190 // Returns true if this elementwise instruction implicitly broadcasts operand 1191 // `operand_idx`. 1192 // 1193 // Precondition: this instruction should be an elementwise operation. 1194 bool ImplicitlyBroadcastsOperand(int64 operand_idx) const; 1195 1196 // Returns true if this instruction is binary and elementwise. 1197 bool IsElementwiseBinary() const; 1198 1199 // Returns whether this instruction may reuse elements of its `i`th operand. ReusesOperandElements(int64 i)1200 bool ReusesOperandElements(int64 i) const { 1201 return OperandElementUse(i) == UseKind::kReuse; 1202 } 1203 1204 // Returns the indices that the given operand appear in the operand list of 1205 // this instruction. Note that an instruction can use the same operand 1206 // multiple times. 1207 std::vector<int64> OperandIndices(const HloInstruction* operand) const; 1208 1209 // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If 1210 // this reshape merely inserts or deletes 1-sized dimensions, return the input 1211 // indices of the deleted dimensions and the output indices of the inserted 1212 // dimensions. 1213 // 1214 // Precondition: this op must be a reshape. 1215 std::tuple<bool, std::vector<int64>, std::vector<int64>> 1216 ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; 1217 1218 // Gets/sets the string identifier for this instruction. name()1219 const string& name() const { return name_; } set_name(tensorflow::StringPiece name)1220 void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); } 1221 1222 // Use the given NameUniquer to select a unique name for the instruction based 1223 // on the instruction's existing name. 
1224 void UniquifyName(NameUniquer* name_uniquer); 1225 1226 // Set the unique id for this instruction to "id" SetUniqueId(int id)1227 void SetUniqueId(int id) { 1228 CHECK_EQ(unique_id_, -1); // Should not be assigned already 1229 CHECK_GE(id, 0); 1230 unique_id_ = id; 1231 } 1232 1233 // Return the unique ID assigned to this node via SetUniqueId (or -1 1234 // if no id has been assigned yet). unique_id()1235 int unique_id() const { return unique_id_; } 1236 1237 // Sets the debug metadata for this instruction. set_metadata(const OpMetadata & metadata)1238 void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } metadata()1239 const OpMetadata& metadata() const { return metadata_; } 1240 1241 // Set/get the computation containing this instruction. set_parent should only 1242 // be called by HloComputation methods which add/remove instructions to 1243 // computations. set_parent(HloComputation * computation)1244 void set_parent(HloComputation* computation) { parent_ = computation; } parent()1245 const HloComputation* parent() const { return parent_; } parent()1246 HloComputation* parent() { return parent_; } 1247 1248 // Returns the module for this instruction. 1249 HloModule* GetModule() const; 1250 1251 // Returns whether we could assign input and output layouts to this 1252 // instruction to make it a bitcast. 1253 bool CouldBeBitcast() const; 1254 1255 // Get/Set the number of partitions per outer dimension (in order, starting 1256 // with outer-most dimension first). Currently used by the parallel cpu 1257 // backend to partition HLOs into parallel tasks. 1258 // TODO(b/62783254) Replace these methods with a more general way to 1259 // annotate HLOs with backend-specific information. 
outer_dimension_partitions()1260 const std::vector<int64>& outer_dimension_partitions() const { 1261 return outer_dimension_partitions_; 1262 } 1263 void set_outer_dimension_partitions( 1264 const std::vector<int64>& outer_dimension_partitions); 1265 1266 // Change the layout for an Constant Hlo instruction to match new_layout. For 1267 // tuple shaped constants shape_index is the path to the internal array 1268 // subshape whose layout needs to be changed. 1269 void RelayoutConstant(const Layout& new_layout, 1270 const ShapeIndex& shape_index = {}); 1271 1272 private: 1273 enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; 1274 1275 // Helper class for computing OperandElementUse for kFusion. 1276 class FusionReusesParamElements; 1277 1278 // See comments on Identical(). 1279 // eq_shapes() is used to check shapes for equality, and would normally be 1280 // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on 1281 // whether we want a layout-sensitive check or not. 1282 bool IdenticalSlowPath( 1283 const HloInstruction& other, 1284 const std::function<bool(const HloComputation*, const HloComputation*)>& 1285 eq_computations, 1286 const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const; 1287 1288 // Creates an n-ary elementwise operation. 1289 static std::unique_ptr<HloInstruction> CreateNary( 1290 const Shape& shape, HloOpcode opcode, 1291 tensorflow::gtl::ArraySlice<HloInstruction*> operands); 1292 1293 // Appends operand to the list of operands and adds this instruction as a user 1294 // of the operand. 1295 void AppendOperand(HloInstruction* operand); 1296 1297 // Adds a user for this instruction. 1298 void AddUser(HloInstruction* user); 1299 1300 // Removes a user for this instruction. 1301 void RemoveUser(HloInstruction* user); 1302 1303 // Internal constructor for a given opcode/shape, other fields must be filled 1304 // by factory methods. 
1305 HloInstruction(HloOpcode opcode, const Shape& shape); 1306 1307 // Fuses the given instruction into this fusion instruction. When add_output 1308 // is false (which is the default), instruction_to_fuse is cloned and the 1309 // clone is placed in the fusion instruction. instruction_to_fuse is 1310 // unchanged. 1311 // 1312 // When add_output is true, a clone of the instruction_to_fuse will be part 1313 // of the output of fusion instructions. The users of instruction_to_fuse 1314 // will be redirected to this fusion instructions. instruction_to_fuse will 1315 // be removed from its parent computation. 1316 // 1317 // Precondition: this->opcode() == HloOpcode::kFusion 1318 HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, 1319 bool add_output = false); 1320 1321 // Clones the given instruction_to_fuse and insert the clone into this fusion 1322 // instruction. If add_output is true, a clone of instruction_to_fuse will 1323 // be in the output of the this fusion instruction (part of the tuple of the 1324 // fusion root). 1325 // 1326 // Precondition: opcode() == HloOpcode::kFusion 1327 HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, 1328 bool add_output = false); 1329 1330 // Clones a fusion instruction with a new shape and operands. 1331 std::unique_ptr<HloInstruction> CloneFusionWithNewOperands( 1332 const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, 1333 HloModule* module = nullptr) const; 1334 1335 // Returns true if this instruction can legally have the dimensions field 1336 // set. Used for checking precondition of dimensions field accessors. 1337 bool CanHaveDimensionsField() const; 1338 1339 // Returns how this instruction uses elements of its `i`th operand. 1340 UseKind OperandElementUse(int64 i) const; 1341 1342 int unique_id_; // Unique to this HloInstruction within a HloModule 1343 1344 // Opcode for this instruction. 
1345 HloOpcode opcode_; 1346 1347 // Instruction operands. 1348 InstructionVector operands_; 1349 1350 // The set of control predecessors of this instruction. 1351 std::vector<HloInstruction*> control_predecessors_; 1352 1353 // The users of this instruction. Users are HLOs where this instruction is an 1354 // operand. The vector users_ and the set user_set_ contain identical 1355 // members. The set enables fast membership testing and the vector enables 1356 // fast, stable iteration. 1357 std::vector<HloInstruction*> users_; 1358 std::unordered_set<const HloInstruction*> user_set_; 1359 1360 // The set of control successors of this instruction. 1361 std::vector<HloInstruction*> control_successors_; 1362 1363 // The computation in which this instruction is contained. 1364 HloComputation* parent_ = nullptr; 1365 1366 // Shape of outfeed request. 1367 Shape outfeed_shape_; 1368 1369 // Result shape of this instruction. 1370 Shape shape_; 1371 1372 // Literal, only present for kConstant. 1373 std::unique_ptr<Literal> literal_; 1374 1375 // Constant index, only present for kGetTupleElement. 1376 int64 tuple_index_ = -1; 1377 1378 // Dimensions present for some operations that require reshaping or 1379 // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. 1380 std::vector<int64> dimensions_; 1381 1382 // Describes the window in a windowed operation such as convolution. 1383 std::unique_ptr<Window> window_; 1384 1385 // Describes the dimension numbers used for a convolution. 1386 std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_; 1387 1388 // Describes the dimension numbers used for a dot. 1389 std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_; 1390 1391 std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; 1392 std::vector<int64> gather_window_bounds_; 1393 1394 // Describes FFT type for an FFT instruction. 
1395 FftType fft_type_ = FftType::FFT; 1396 1397 // Indicates the FFT length for an FFT instruction. 1398 std::vector<int64> fft_length_; 1399 1400 // Describes the [begin, end) index range for a slice. 1401 std::vector<int64> slice_starts_; 1402 std::vector<int64> slice_limits_; 1403 std::vector<int64> slice_strides_; 1404 1405 // Describes whether the slice can be lowered to an offset into the operand. 1406 bool is_in_place_slice_ = false; 1407 1408 // The bit sizes for a reduce-precision operation. 1409 int32 exponent_bits_ = 0; 1410 int32 mantissa_bits_ = 0; 1411 1412 // Describes the [start, start + size) range size for a dynamic slice 1413 // ('start' is specified dynamically in the second operand of the operation). 1414 std::vector<int64> dynamic_slice_sizes_; 1415 1416 // The padding configuration that describes the edge padding and interior 1417 // padding of this pad instruction. Only set for pad instructions. 1418 std::unique_ptr<PaddingConfig> padding_config_; 1419 1420 // The type of the fusion. Used by kFusion only. 1421 FusionKind fusion_kind_; 1422 1423 // The sharding, if one exists. 1424 std::unique_ptr<HloSharding> sharding_; 1425 1426 // For parameter instructions this field holds the parameter number. 1427 int64 parameter_number_ = 0; 1428 1429 // Name of a global symbol to call, only present for kCustomCall. 1430 string custom_call_target_; 1431 1432 // Name to use for host send/recv channels, only present for kHostCompute. 1433 string channel_name_; 1434 1435 // Estimate of the duration of a host computation in nanoseconds. 1436 int64 cost_estimate_ns_; 1437 1438 // Computations called by this instruction. 1439 std::vector<HloComputation*> called_computations_; 1440 1441 // Indices of computations in called_computations_ for instructions which call 1442 // multiple computations. 1443 enum { 1444 // kWhile computations. 1445 kBodyComputationIndex = 0, 1446 kConditionComputationIndex = 1, 1447 1448 // kSelectAndScatter computations. 
1449 kSelectComputationIndex = 0, 1450 kScatterComputationIndex = 1, 1451 1452 // kConditional computations. 1453 kTrueComputationIndex = 0, 1454 kFalseComputationIndex = 1, 1455 }; 1456 1457 // Outfeed configuration information, only present for kOutfeed. 1458 string outfeed_config_; 1459 1460 // A trace instruction that consumes this instruction. 1461 // 1462 // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as 1463 // an operand. 1464 HloInstruction* trace_instruction_ = nullptr; 1465 1466 // The distribution requested for random number generation. 1467 // Only present for kRng. 1468 RandomDistribution distribution_; 1469 1470 // A small float number added to the variance to avoid divide-by-zero error. 1471 // Only present for kBatchNormTraining. 1472 float epsilon_ = 0.0f; 1473 1474 // An integer value representing the index of the feature dimension. 1475 // Only present for kBatchNormTraining. 1476 int64 feature_index_ = -1; 1477 1478 // Represents a unique identifier for each Send/Recv instruction pair. 1479 // Only present for kSend or kRecv. 1480 int64 channel_id_ = -1; 1481 1482 // The string representation of the infeed configuration. 1483 string infeed_config_; 1484 1485 // String identifier for instruction. 1486 string name_; 1487 1488 // Metadata for debugging. 1489 OpMetadata metadata_; 1490 1491 // The number of partitions per outer dimension (listed in order from 1492 // outer-most dimension first). 1493 std::vector<int64> outer_dimension_partitions_; 1494 1495 TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction); 1496 }; 1497 1498 string ToString(HloInstruction::FusionKind kind); 1499 StatusOr<HloInstruction::FusionKind> StringToFusionKind( 1500 const string& kind_name); 1501 1502 // Custom (de)stringification functions for protos that live inside 1503 // HloInstruction. 
1504 string PaddingConfigToString(const PaddingConfig& padding); 1505 string OpMetadataToString(const OpMetadata& metadata); 1506 string RandomDistributionToString(const RandomDistribution& distribution); 1507 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name); 1508 1509 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); 1510 1511 // Map classes that guarantee a deterministic iteration order when the key is 1512 // an HloInstruction* or a const HloInstruction*. 1513 // To make the iteration order over the map deterministic, the comparator 1514 // should not be using the pointer values, but rather an intrinsic property of 1515 // the hlo. 1516 // 1517 // Note that this cannot be used for HLO instructions across multiple modules 1518 // since the id of HLO instructions are only unique within each HLO module. 1519 struct HloPtrComparator { operatorHloPtrComparator1520 bool operator()(const HloInstruction* const& lhs, 1521 const HloInstruction* const& rhs) const { 1522 return lhs->unique_id() < rhs->unique_id(); 1523 } 1524 }; 1525 1526 template <typename ValueT> 1527 using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>; 1528 1529 template <typename ValueT> 1530 using ConstHloInstructionMap = 1531 std::map<const HloInstruction*, ValueT, HloPtrComparator>; 1532 1533 using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>; 1534 using ConstHloInstructionSet = 1535 std::set<const HloInstruction*, HloPtrComparator>; 1536 1537 } // namespace xla 1538 1539 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ 1540