Home
last modified time | relevance | path

Searched refs:SpmdBuilder (Results 1 – 9 of 9) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dspmd_partitioner_util.h45 SpmdBuilder* b);
47 HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b);
50 HloInstruction* CreateOne(const Shape& shape, SpmdBuilder* b);
54 SpmdBuilder* b) { in CreateR0WithType()
61 inline HloInstruction* CreateFirstWithType(PrimitiveType type, SpmdBuilder* b) { in CreateFirstWithType()
70 inline HloInstruction* CreateLastWithType(PrimitiveType type, SpmdBuilder* b) { in CreateLastWithType()
107 HloInstruction* partition_id, SpmdBuilder* b,
112 const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b);
117 SpmdBuilder* b,
127 HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b);
[all …]
Dspmd_partitioner.h70 class SpmdBuilder : public HloComputation::Builder {
72 SpmdBuilder(const std::string& name, HloInstruction* hlo) in SpmdBuilder() function
115 std::function<HloInstruction*(SpmdBuilder*)> create_partition_id;
119 SpmdBuilder*, HloInstruction* operand, HloComputation* reduction,
126 SpmdBuilder*, HloInstruction* operand,
133 SpmdBuilder*, absl::Span<HloInstruction* const> operands,
141 SpmdBuilder*, HloInstruction* operand, const Shape& ag_shape,
211 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
218 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
233 SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding,
[all …]
Dconvolution_handler.h33 HloInstruction*, HloInstruction*, SpmdBuilder*,
37 HloInstruction* partition_id, HloModule* module, SpmdBuilder* b);
Dfft_handler.cc55 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { in PadEachPartitionWithHaloExchange()
114 SpmdBuilder* b) { in ShuffleWithinEachPartitionUsingOneHot()
162 int64* next_channel_id, SpmdBuilder* b) { in ShuffleDataWithAllToAll()
174 SpmdBuilder* b) { in GetCorrectionFactor()
231 HloModule* module, SpmdBuilder* b) { in GetFinalFftUsingCollectivePermute()
239 SpmdBuilder body_b("fft_collective_permute_body", hlo); in GetFinalFftUsingCollectivePermute()
312 SpmdBuilder cond_b("fft_collective_permute_condition", hlo); in GetFinalFftUsingCollectivePermute()
341 SpmdBuilder* b) { in SliceValidData()
Dconvolution_handler.cc45 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionConvolutionWithBatchGroupCount() argument
48 int64 num_partitions, SpmdBuilder* b) { in PartitionConvolutionWithBatchGroupCount()
137 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionConvolutionWithFeatureGroupCount() argument
140 int64 num_partitions, SpmdBuilder* b) { in PartitionConvolutionWithFeatureGroupCount()
231 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() argument
234 HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
518 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() argument
521 HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
744 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionConvolutionTiledOutput() argument
746 const Window& conv_window, HloInstruction* original_hlo, SpmdBuilder* b) { in PartitionConvolutionTiledOutput()
[all …]
Dspmd_partitioner_util.cc55 SpmdBuilder* b) { in CreateConstant()
71 HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) { in CreateZero()
92 HloInstruction* CreateOne(const Shape& shape, SpmdBuilder* b) { in CreateOne()
206 HloInstruction* partition_id, SpmdBuilder* b, in MakePartitionOffsets()
240 const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { in MakeTiledPartitionOrdinals()
251 SpmdBuilder* b, HloComputation* computation) { in PadToShape()
294 HloInstruction* hlo, const HloSharding& sharding, SpmdBuilder* b) { in PadBaseShapeBeforeUnevenTiledSharding()
413 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { in TileToPartialReplicateHaloExchange()
507 int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { in PadFromPartialReplicateShape()
683 HloInstruction* shard_ordinal, SpmdBuilder* b) const { in Calculate()
[all …]
Ddot_handler.cc79 [&](HloInstruction* l, HloInstruction* r, SpmdBuilder* b, in HandleDot()
471 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionBaseCase() argument
480 const SpmdPartitionerOptions& options, SpmdBuilder* b, in PartitionBaseCase()
659 SpmdBuilder body_b("windowed_dot_general_body", original_hlo); in PartitionBaseCase()
1349 SpmdBuilder cp_b("window_collective_permute", original_hlo); in PartitionBaseCase()
1367 SpmdBuilder ncp_b("last_iteration_noop", original_hlo); in PartitionBaseCase()
1401 SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo); in PartitionBaseCase()
1641 HloInstruction*, HloInstruction*, SpmdBuilder*,
1644 const SpmdPartitionerOptions& options, SpmdBuilder* b,
1655 HloInstruction*, HloInstruction*, SpmdBuilder*, in PartitionDotGroupOnBatch() argument
[all …]
Dgather_scatter_handler.cc56 int64 index_vector_dim, SpmdBuilder* b) { in IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims()
138 PartitionedHlo& operand, PartitionedHlo& indices, SpmdBuilder* b) { in PartitionIndexOnlyPartition()
184 SpmdBuilder* b = visitor->builder(); in ParititonPassthroughOperand()
219 SpmdBuilder* b = visitor->builder(); in ParititonTrivialIndexedOperandDimension()
326 SpmdBuilder* b = visitor->builder(); in PartitionIndexParallelDimensions()
Dspmd_partitioner.cc216 HloInstruction* SpmdBuilder::AddInstruction( in AddInstruction()
1362 b_(SpmdBuilder(computation->name() + "_spmd", /*hlo=*/nullptr)), in SpmdPartitioningVisitor()
2198 SpmdBuilder true_b("true_computation", visiting_hlo_); in HandleSingleDevice()
2213 SpmdBuilder false_b("false_computation", visiting_hlo_); in HandleSingleDevice()
2567 SpmdBuilder branch_b(absl::StrCat("infeed_branch_", i), visiting_hlo_); in HandleInfeed()
2971 SpmdBuilder branch_b(absl::StrCat("outfeed_branch_", i), visiting_hlo_); in HandleOutfeed()
3414 [](SpmdBuilder* b) { in GetDefaultCollectiveOpsCreator()
3418 SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction, in GetDefaultCollectiveOpsCreator()
3448 [](SpmdBuilder* b, HloInstruction* operand, in GetDefaultCollectiveOpsCreator()
3454 [](SpmdBuilder* b, absl::Span<HloInstruction* const> operands, in GetDefaultCollectiveOpsCreator()
[all …]