Searched refs:PartitionedHlo (Results 1 – 9 of 9) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | spmd_partitioner.h | 258 class PartitionedHlo { 271 std::vector<std::pair<HloSharding, PartitionedHlo>> reshard_cache; 293 PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state) in PartitionedHlo() function 307 PartitionedHlo Reshard(const HloSharding& target); 313 PartitionedHlo PadWithValue(HloInstruction* pad_value, 338 PartitionedHlo Replicate(); 346 PartitionedHlo ReshardNoCache(const HloSharding& target); 349 PartitionedHlo Broadcast() const; 353 PartitionedHlo ReshardWithAllToAll( 358 PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const; [all …]
|
D | gather_scatter_handler.cc | 35 const PartitionedHlo& operand, absl::Span<const int64> index_map, in GatherScatterOperandPartitionedOnlyOnTrivialSliceDims() 54 const PartitionedHlo& operand, const PartitionedHlo& replicated_indices, in IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims() 127 PartitionedHlo& operand, 128 PartitionedHlo& indices, 138 PartitionedHlo& operand, PartitionedHlo& indices, SpmdBuilder* b) { in PartitionIndexOnlyPartition() 168 return PartitionedHlo(pgather, gather->shape(), operand.state()) in PartitionIndexOnlyPartition() 182 PartitionedHlo& operand, PartitionedHlo& indices, in ParititonPassthroughOperand() 205 return PartitionedHlo(pgather, output_shape, operand.state()) in ParititonPassthroughOperand() 217 PartitionedHlo& operand, PartitionedHlo& indices, in ParititonTrivialIndexedOperandDimension() 303 return PartitionedHlo(ar, output_shape, operand.state()) in ParititonTrivialIndexedOperandDimension() [all …]
|
D | convolution_handler.h | 30 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
|
D | convolution_handler.cc | 42 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionConvolutionWithBatchGroupCount() 127 return PartitionedHlo(sharded_conv, output_base_shape, lhs.state()) in PartitionConvolutionWithBatchGroupCount() 134 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionConvolutionWithFeatureGroupCount() 219 return PartitionedHlo(sharded_conv, output_base_shape, lhs.state()) in PartitionConvolutionWithFeatureGroupCount() 228 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() 506 return PartitionedHlo(ar, output_base_shape, lhs.state()) in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() 515 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() 558 rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state()); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() 733 return PartitionedHlo(ar, output_base_shape, lhs.state()) in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() 741 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionConvolutionTiledOutput() [all …]
|
D | spmd_partitioner.cc | 345 PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) { in Reshard() 369 PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { in ReshardNoCache() 394 PartitionedHlo( in ReshardNoCache() 402 return PartitionedHlo(tuple, base_shape_, state_); in ReshardNoCache() 451 return PartitionedHlo(copy, base_shape_, state_); in ReshardNoCache() 464 return PartitionedHlo(partially_sharded, base_shape(), state_); in ReshardNoCache() 476 return PartitionedHlo(slice, base_shape_, state_); in ReshardNoCache() 479 PartitionedHlo PartitionedHlo::PadWithValue( in PadWithValue() 542 return PartitionedHlo(result, base_shape_, state_); in PadWithValue() 545 absl::optional<PartitionedHlo::WindowedInputShardReturnValue> [all …]
|
D | dot_handler.cc | 467 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionBaseCase() 527 return PartitionedHlo(dot, output_base_shape, lhs.state()) in PartitionBaseCase() 835 PartitionedHlo(padded_slice_operand, in PartitionBaseCase() 1094 PartitionedHlo(slice_operand, slice_operand->shape(), state) in PartitionBaseCase() 1523 return PartitionedHlo(ar, output_base_shape, lhs.state()) in PartitionBaseCase() 1637 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, 1649 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionDotGroupOnBatch() 1705 PartitionedHlo per_group_lhs = lhs; in PartitionDotGroupOnBatch() 1706 PartitionedHlo per_group_rhs = rhs; in PartitionDotGroupOnBatch() 1735 per_group_lhs = PartitionedHlo( in PartitionDotGroupOnBatch() [all …]
|
D | spmd_partitioner_util.h | 267 HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original, 346 PartitionedHlo::PartitioningState CreatePerGroupPartitioningState( 347 const PartitionedHlo::PartitioningState& state,
|
D | fft_handler.cc | 430 PartitionedHlo(result, hlo->shape(), partitioned_input.state()); in HandleFft()
|
D | spmd_partitioner_util.cc | 1071 HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original, in HaloExchangeToPadOnLeft() 1671 PartitionedHlo::PartitioningState CreatePerGroupPartitioningState( in CreatePerGroupPartitioningState() 1672 const PartitionedHlo::PartitioningState& state, in CreatePerGroupPartitioningState() 1687 grouped_cache = absl::make_unique<PartitionedHlo::ReshardCache>(); in CreatePerGroupPartitioningState()
|