/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | spmd_partitioner_util.h | 45 SpmdBuilder* b); 47 HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b); 50 HloInstruction* CreateOne(const Shape& shape, SpmdBuilder* b); 54 SpmdBuilder* b) { in CreateR0WithType() 61 inline HloInstruction* CreateFirstWithType(PrimitiveType type, SpmdBuilder* b) { in CreateFirstWithType() 70 inline HloInstruction* CreateLastWithType(PrimitiveType type, SpmdBuilder* b) { in CreateLastWithType() 107 HloInstruction* partition_id, SpmdBuilder* b, 112 const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b); 117 SpmdBuilder* b, 127 HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b); [all …]
|
D | spmd_partitioner.h | 70 class SpmdBuilder : public HloComputation::Builder { 72 SpmdBuilder(const std::string& name, HloInstruction* hlo) in SpmdBuilder() function 115 std::function<HloInstruction*(SpmdBuilder*)> create_partition_id; 119 SpmdBuilder*, HloInstruction* operand, HloComputation* reduction, 126 SpmdBuilder*, HloInstruction* operand, 133 SpmdBuilder*, absl::Span<HloInstruction* const> operands, 141 SpmdBuilder*, HloInstruction* operand, const Shape& ag_shape, 211 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding, 218 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding, 233 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding, [all …]
|
D | convolution_handler.h | 33 HloInstruction*, HloInstruction*, SpmdBuilder*, 37 HloInstruction* partition_id, HloModule* module, SpmdBuilder* b);
|
D | fft_handler.cc | 55 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { in PadEachPartitionWithHaloExchange() 114 SpmdBuilder* b) { in ShuffleWithinEachPartitionUsingOneHot() 162 int64* next_channel_id, SpmdBuilder* b) { in ShuffleDataWithAllToAll() 174 SpmdBuilder* b) { in GetCorrectionFactor() 231 HloModule* module, SpmdBuilder* b) { in GetFinalFftUsingCollectivePermute() 239 SpmdBuilder body_b("fft_collective_permute_body", hlo); in GetFinalFftUsingCollectivePermute() 312 SpmdBuilder cond_b("fft_collective_permute_condition", hlo); in GetFinalFftUsingCollectivePermute() 341 SpmdBuilder* b) { in SliceValidData()
|
D | convolution_handler.cc | 45 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionConvolutionWithBatchGroupCount() argument 48 int64 num_partitions, SpmdBuilder* b) { in PartitionConvolutionWithBatchGroupCount() 137 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionConvolutionWithFeatureGroupCount() argument 140 int64 num_partitions, SpmdBuilder* b) { in PartitionConvolutionWithFeatureGroupCount() 231 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() argument 234 HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() 518 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() argument 521 HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() 744 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionConvolutionTiledOutput() argument 746 const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) { in PartitionConvolutionTiledOutput() [all …]
|
D | spmd_partitioner_util.cc | 55 SpmdBuilder* b) { in CreateConstant() 71 HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) { in CreateZero() 92 HloInstruction* CreateOne(const Shape& shape, SpmdBuilder* b) { in CreateOne() 206 HloInstruction* partition_id, SpmdBuilder* b, in MakePartitionOffsets() 240 const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { in MakeTiledPartitionOrdinals() 251 SpmdBuilder* b, HloComputation* computation) { in PadToShape() 294 HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) { in PadBaseShapeBeforeUnevenTiledSharding() 413 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { in TileToPartialReplicateHaloExchange() 507 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { in PadFromPartialReplicateShape() 683 HloInstruction* shard_ordinal, SpmdBuilder* b) const { in Calculate() [all …]
|
D | dot_handler.cc | 79 [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b, in HandleDot() 471 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionBaseCase() argument 480 const SpmdPartitionerOptions& options, SpmdBuilder* b, in PartitionBaseCase() 659 SpmdBuilder body_b("windowed_dot_general_body", original_hlo); in PartitionBaseCase() 1349 SpmdBuilder cp_b("window_collective_permute", original_hlo); in PartitionBaseCase() 1367 SpmdBuilder ncp_b("last_iteration_noop", original_hlo); in PartitionBaseCase() 1401 SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo); in PartitionBaseCase() 1641 HloInstruction*, HloInstruction*, SpmdBuilder*, 1644 const SpmdPartitionerOptions& options, SpmdBuilder* b, 1655 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionDotGroupOnBatch() argument [all …]
|
D | gather_scatter_handler.cc | 56 int64 index_vector_dim, SpmdBuilder* b) { in IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims() 138 PartitionedHlo& operand, PartitionedHlo& indices, SpmdBuilder* b) { in PartitionIndexOnlyPartition() 184 SpmdBuilder* b = visitor->builder(); in ParititonPassthroughOperand() 219 SpmdBuilder* b = visitor->builder(); in ParititonTrivialIndexedOperandDimension() 326 SpmdBuilder* b = visitor->builder(); in PartitionIndexParallelDimensions()
|
D | spmd_partitioner.cc | 216 HloInstruction* SpmdBuilder::AddInstruction( in AddInstruction() 1362 b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)), in SpmdPartitioningVisitor() 2198 SpmdBuilder true_b("true_computation", visiting_hlo_); in HandleSingleDevice() 2213 SpmdBuilder false_b("false_computation", visiting_hlo_); in HandleSingleDevice() 2567 SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_); in HandleInfeed() 2971 SpmdBuilder branch_b(absl::StrCat("outfeed_branch_", i), visiting_hlo_); in HandleOutfeed() 3414 [](SpmdBuilder* b) { in GetDefaultCollectiveOpsCreator() 3418 SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction, in GetDefaultCollectiveOpsCreator() 3448 [](SpmdBuilder* b, HloInstruction* operand, in GetDefaultCollectiveOpsCreator() 3454 [](SpmdBuilder* b, absl::Span<HloInstruction* const> operands, in GetDefaultCollectiveOpsCreator() [all …]
|