Home
last modified time | relevance | path

Searched refs:GroupShardingOnDims (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Ddot_handler.cc1701 GroupShardingOnDims(rhs.sharding(), rhs_dims), in PartitionDotGroupOnBatch()
1702 GroupShardingOnDims(lhs.sharding(), lhs_dims))); in PartitionDotGroupOnBatch()
1704 auto output_grouped = GroupShardingOnDims(output_sharding, output_dims); in PartitionDotGroupOnBatch()
1708 auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims); in PartitionDotGroupOnBatch()
1709 auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims); in PartitionDotGroupOnBatch()
1721 GroupShardingOnDims( in PartitionDotGroupOnBatch()
1807 GroupShardingOnDims(operand.base_shape().rank() < in PartitionDotGroupOnBatch()
1898 GroupShardingOnDims(output_sharding, output_dims); in GetNonContractingPartitionGroupedShardingForMatchedOperand()
1902 GroupShardingOnDims( in GetNonContractingPartitionGroupedShardingForMatchedOperand()
1928 GroupShardingOnDims(output_sharding, output_dims); in GetNonContractingPartitionGroupedShardingForOtherOperand()
[all …]
Dspmd_partitioner_util.h315 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
320 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
Dgather_scatter_handler.cc342 GroupShardingOnDims(indices_sharding, indices_parallel_dims); in PartitionIndexParallelDimensions()
344 GroupShardingOnDims(operand_sharding, operand_parallel_dims); in PartitionIndexParallelDimensions()
442 GroupShardingOnDims(gather_output_sharding, output_parallel_dims); in PartitionIndexParallelDimensions()
581 auto sharding_grouped = GroupShardingOnDims( in HandleScatter()
Dspmd_partitioner_util.cc1378 GroupedSharding GroupShardingOnDims(const HloSharding& sharding, in GroupShardingOnDims() function
1381 return GroupShardingOnDims(sharding, group_dims, group_dim_shards); in GroupShardingOnDims()
1384 GroupedSharding GroupShardingOnDims(const HloSharding& sharding, in GroupShardingOnDims() function
1551 auto sharding_grouped = GroupShardingOnDims(sharding, sharding_dims); in AlignShardingOnDims()
1552 auto reference_grouped = GroupShardingOnDims(reference, reference_dims); in AlignShardingOnDims()
Dspmd_partitioner.cc459 auto target_grouped = GroupShardingOnDims(target, group_dims); in ReshardNoCache()
957 GroupShardingOnDims(temp_sharding, replicate_dims, replicate_factors); in ReshardToPartialReplicateWithAllGather()
1558 auto grouped = GroupShardingOnDims(sharding, non_concat_dims); in HandleConcatenate()
2771 GroupShardingOnDims(inputs[0].sharding(), preserved_dims); in HandleReduce()
3086 auto sharding_grouped = GroupShardingOnDims(hlo->sharding(), group_dims); in HandleRng()