Home
last modified time | relevance | path

Searched refs:Reshard (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
gather_scatter_handler.cc
169 .Reshard(gather->sharding()) in PartitionIndexOnlyPartition()
190 indices = indices.Reshard(HloSharding::Replicate()); in ParititonPassthroughOperand()
206 .Reshard(output_sharding) in ParititonPassthroughOperand()
226 indices = indices.Reshard(HloSharding::Replicate()); in ParititonTrivialIndexedOperandDimension()
304 .Reshard(output_sharding) in ParititonTrivialIndexedOperandDimension()
356 operand = operand.Reshard(operand_sharding); in PartitionIndexParallelDimensions()
357 indices = indices.Reshard(indices_sharding); in PartitionIndexParallelDimensions()
451 .Reshard(output_sharding) in PartitionIndexParallelDimensions()
577 updates = updates.Reshard(*new_updates_sharding); in HandleScatter()
611 .Reshard(hlo->sharding()) in HandleScatter()
[all …]
spmd_partitioner.cc
345 PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) { in Reshard() function in xla::spmd::PartitionedHlo
381 return Reshard(target.GetTupleSharding(shape).ValueOrDie()); in ReshardNoCache()
396 .Reshard(target.GetSubSharding(shape, {i})) in ReshardNoCache()
443 return Replicate().Reshard(target); in ReshardNoCache()
769 return Reshard(target).ReshardAsWindowedInput(window, target, pad_value); in ReshardAsWindowedInput()
1299 return partitioned_hlo.Reshard(target); in ReshardPartialReplicateWithAllToAll()
1302 auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding); in ReshardPartialReplicateWithAllToAll()
1395 new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo()); in DefaultAction()
1403 .Reshard(hlo->sharding())); in DefaultAction()
1476 GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo()); in HandleElementwise()
[all …]
convolution_handler.cc
114 rhs = rhs.Reshard(aligned_rhs_sharding); in PartitionConvolutionWithBatchGroupCount()
116 lhs = lhs.Reshard(aligned_lhs_sharding); in PartitionConvolutionWithBatchGroupCount()
128 .Reshard(output_sharding) in PartitionConvolutionWithBatchGroupCount()
206 rhs = rhs.Reshard(aligned_rhs_sharding); in PartitionConvolutionWithFeatureGroupCount()
208 lhs = lhs.Reshard(aligned_lhs_sharding); in PartitionConvolutionWithFeatureGroupCount()
220 .Reshard(output_sharding) in PartitionConvolutionWithFeatureGroupCount()
274 lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
281 rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
507 .Reshard(output_sharding) in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
583 lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
[all …]
dot_handler.cc
528 .Reshard(output_sharding) in PartitionBaseCase()
549 auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); in PartitionBaseCase()
566 auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); in PartitionBaseCase()
837 .Reshard(*padded_slice_sharding) in PartitionBaseCase()
1095 .Reshard(*slice_sharding) in PartitionBaseCase()
1504 lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); in PartitionBaseCase()
1509 rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); in PartitionBaseCase()
1524 .Reshard(output_sharding) in PartitionBaseCase()
1532 auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo(); in PartitionBaseCase()
1542 auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo(); in PartitionBaseCase()
[all …]
spmd_partitioner.h
307 PartitionedHlo Reshard(const HloSharding& target);