Searched refs:Reshard (Results 1 – 5 of 5) sorted by relevance
169 .Reshard(gather->sharding()) in PartitionIndexOnlyPartition()190 indices = indices.Reshard(HloSharding::Replicate()); in ParititonPassthroughOperand()206 .Reshard(output_sharding) in ParititonPassthroughOperand()226 indices = indices.Reshard(HloSharding::Replicate()); in ParititonTrivialIndexedOperandDimension()304 .Reshard(output_sharding) in ParititonTrivialIndexedOperandDimension()356 operand = operand.Reshard(operand_sharding); in PartitionIndexParallelDimensions()357 indices = indices.Reshard(indices_sharding); in PartitionIndexParallelDimensions()451 .Reshard(output_sharding) in PartitionIndexParallelDimensions()577 updates = updates.Reshard(*new_updates_sharding); in HandleScatter()611 .Reshard(hlo->sharding()) in HandleScatter()[all …]
345 PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) { in Reshard() function in xla::spmd::PartitionedHlo381 return Reshard(target.GetTupleSharding(shape).ValueOrDie()); in ReshardNoCache()396 .Reshard(target.GetSubSharding(shape, {i})) in ReshardNoCache()443 return Replicate().Reshard(target); in ReshardNoCache()769 return Reshard(target).ReshardAsWindowedInput(window, target, pad_value); in ReshardAsWindowedInput()1299 return partitioned_hlo.Reshard(target); in ReshardPartialReplicateWithAllToAll()1302 auto partitioned_hlo = Reshard(tmp_partial_replicate_sharding); in ReshardPartialReplicateWithAllToAll()1395 new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo()); in DefaultAction()1403 .Reshard(hlo->sharding())); in DefaultAction()1476 GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo()); in HandleElementwise()[all …]
114 rhs = rhs.Reshard(aligned_rhs_sharding); in PartitionConvolutionWithBatchGroupCount()116 lhs = lhs.Reshard(aligned_lhs_sharding); in PartitionConvolutionWithBatchGroupCount()128 .Reshard(output_sharding) in PartitionConvolutionWithBatchGroupCount()206 rhs = rhs.Reshard(aligned_rhs_sharding); in PartitionConvolutionWithFeatureGroupCount()208 lhs = lhs.Reshard(aligned_lhs_sharding); in PartitionConvolutionWithFeatureGroupCount()220 .Reshard(output_sharding) in PartitionConvolutionWithFeatureGroupCount()274 lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()281 rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()507 .Reshard(output_sharding) in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()583 lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()[all …]
528 .Reshard(output_sharding) in PartitionBaseCase()549 auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); in PartitionBaseCase()566 auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); in PartitionBaseCase()837 .Reshard(*padded_slice_sharding) in PartitionBaseCase()1095 .Reshard(*slice_sharding) in PartitionBaseCase()1504 lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); in PartitionBaseCase()1509 rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); in PartitionBaseCase()1524 .Reshard(output_sharding) in PartitionBaseCase()1532 auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo(); in PartitionBaseCase()1542 auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo(); in PartitionBaseCase()[all …]
307 PartitionedHlo Reshard(const HloSharding& target);