Home
last modified time | relevance | path

Searched refs:PartitionedHlo (Results 1 – 9 of 9) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dspmd_partitioner.h258 class PartitionedHlo {
271 std::vector<std::pair<HloSharding, PartitionedHlo>> reshard_cache;
293 PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state) in PartitionedHlo() function
307 PartitionedHlo Reshard(const HloSharding& target);
313 PartitionedHlo PadWithValue(HloInstruction* pad_value,
338 PartitionedHlo Replicate();
346 PartitionedHlo ReshardNoCache(const HloSharding& target);
349 PartitionedHlo Broadcast() const;
353 PartitionedHlo ReshardWithAllToAll(
358 PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;
[all …]
Dgather_scatter_handler.cc35 const PartitionedHlo& operand, absl::Span<const int64> index_map, in GatherScatterOperandPartitionedOnlyOnTrivialSliceDims()
54 const PartitionedHlo& operand, const PartitionedHlo& replicated_indices, in IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims()
127 PartitionedHlo& operand,
128 PartitionedHlo& indices,
138 PartitionedHlo& operand, PartitionedHlo& indices, SpmdBuilder* b) { in PartitionIndexOnlyPartition()
168 return PartitionedHlo(pgather, gather->shape(), operand.state()) in PartitionIndexOnlyPartition()
182 PartitionedHlo& operand, PartitionedHlo& indices, in ParititonPassthroughOperand()
205 return PartitionedHlo(pgather, output_shape, operand.state()) in ParititonPassthroughOperand()
217 PartitionedHlo& operand, PartitionedHlo& indices, in ParititonTrivialIndexedOperandDimension()
303 return PartitionedHlo(ar, output_shape, operand.state()) in ParititonTrivialIndexedOperandDimension()
[all …]
Dconvolution_handler.h30 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
Dconvolution_handler.cc42 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionConvolutionWithBatchGroupCount()
127 return PartitionedHlo(sharded_conv, output_base_shape, lhs.state()) in PartitionConvolutionWithBatchGroupCount()
134 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionConvolutionWithFeatureGroupCount()
219 return PartitionedHlo(sharded_conv, output_base_shape, lhs.state()) in PartitionConvolutionWithFeatureGroupCount()
228 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
506 return PartitionedHlo(ar, output_base_shape, lhs.state()) in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
515 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
558 rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state()); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
733 return PartitionedHlo(ar, output_base_shape, lhs.state()) in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
741 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionConvolutionTiledOutput()
[all …]
Dspmd_partitioner.cc345 PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) { in Reshard()
369 PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { in ReshardNoCache()
394 PartitionedHlo( in ReshardNoCache()
402 return PartitionedHlo(tuple, base_shape_, state_); in ReshardNoCache()
451 return PartitionedHlo(copy, base_shape_, state_); in ReshardNoCache()
464 return PartitionedHlo(partially_sharded, base_shape(), state_); in ReshardNoCache()
476 return PartitionedHlo(slice, base_shape_, state_); in ReshardNoCache()
479 PartitionedHlo PartitionedHlo::PadWithValue( in PadWithValue()
542 return PartitionedHlo(result, base_shape_, state_); in PadWithValue()
545 absl::optional<PartitionedHlo::WindowedInputShardReturnValue>
[all …]
Ddot_handler.cc467 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionBaseCase()
527 return PartitionedHlo(dot, output_base_shape, lhs.state()) in PartitionBaseCase()
835 PartitionedHlo(padded_slice_operand, in PartitionBaseCase()
1094 PartitionedHlo(slice_operand, slice_operand->shape(), state) in PartitionBaseCase()
1523 return PartitionedHlo(ar, output_base_shape, lhs.state()) in PartitionBaseCase()
1637 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
1649 PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, in PartitionDotGroupOnBatch()
1705 PartitionedHlo per_group_lhs = lhs; in PartitionDotGroupOnBatch()
1706 PartitionedHlo per_group_rhs = rhs; in PartitionDotGroupOnBatch()
1735 per_group_lhs = PartitionedHlo( in PartitionDotGroupOnBatch()
[all …]
Dspmd_partitioner_util.h267 HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
346 PartitionedHlo::PartitioningState CreatePerGroupPartitioningState(
347 const PartitionedHlo::PartitioningState& state,
Dfft_handler.cc430 PartitionedHlo(result, hlo->shape(), partitioned_input.state()); in HandleFft()
Dspmd_partitioner_util.cc1071 HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original, in HaloExchangeToPadOnLeft()
1671 PartitionedHlo::PartitioningState CreatePerGroupPartitioningState( in CreatePerGroupPartitioningState()
1672 const PartitionedHlo::PartitioningState& state, in CreatePerGroupPartitioningState()
1687 grouped_cache = absl::make_unique<PartitionedHlo::ReshardCache>(); in CreatePerGroupPartitioningState()