Searched refs:GroupShardingOnDims (Results 1 – 5 of 5) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | dot_handler.cc | 1701 GroupShardingOnDims(rhs.sharding(), rhs_dims), in PartitionDotGroupOnBatch() 1702 GroupShardingOnDims(lhs.sharding(), lhs_dims))); in PartitionDotGroupOnBatch() 1704 auto output_grouped = GroupShardingOnDims(output_sharding, output_dims); in PartitionDotGroupOnBatch() 1708 auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims); in PartitionDotGroupOnBatch() 1709 auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims); in PartitionDotGroupOnBatch() 1721 GroupShardingOnDims( in PartitionDotGroupOnBatch() 1807 GroupShardingOnDims(operand.base_shape().rank() < in PartitionDotGroupOnBatch() 1898 GroupShardingOnDims(output_sharding, output_dims); in GetNonContractingPartitionGroupedShardingForMatchedOperand() 1902 GroupShardingOnDims( in GetNonContractingPartitionGroupedShardingForMatchedOperand() 1928 GroupShardingOnDims(output_sharding, output_dims); in GetNonContractingPartitionGroupedShardingForOtherOperand() [all …]
|
D | spmd_partitioner_util.h | 315 GroupedSharding GroupShardingOnDims(const HloSharding& sharding, 320 GroupedSharding GroupShardingOnDims(const HloSharding& sharding,
|
D | gather_scatter_handler.cc | 342 GroupShardingOnDims(indices_sharding, indices_parallel_dims); in PartitionIndexParallelDimensions() 344 GroupShardingOnDims(operand_sharding, operand_parallel_dims); in PartitionIndexParallelDimensions() 442 GroupShardingOnDims(gather_output_sharding, output_parallel_dims); in PartitionIndexParallelDimensions() 581 auto sharding_grouped = GroupShardingOnDims( in HandleScatter()
|
D | spmd_partitioner_util.cc | 1378 GroupedSharding GroupShardingOnDims(const HloSharding& sharding, in GroupShardingOnDims() function 1381 return GroupShardingOnDims(sharding, group_dims, group_dim_shards); in GroupShardingOnDims() 1384 GroupedSharding GroupShardingOnDims(const HloSharding& sharding, in GroupShardingOnDims() function 1551 auto sharding_grouped = GroupShardingOnDims(sharding, sharding_dims); in AlignShardingOnDims() 1552 auto reference_grouped = GroupShardingOnDims(reference, reference_dims); in AlignShardingOnDims()
|
D | spmd_partitioner.cc | 459 auto target_grouped = GroupShardingOnDims(target, group_dims); in ReshardNoCache() 957 GroupShardingOnDims(temp_sharding, replicate_dims, replicate_factors); in ReshardToPartialReplicateWithAllGather() 1558 auto grouped = GroupShardingOnDims(sharding, non_concat_dims); in HandleConcatenate() 2771 GroupShardingOnDims(inputs[0].sharding(), preserved_dims); in HandleReduce() 3086 auto sharding_grouped = GroupShardingOnDims(hlo->sharding(), group_dims); in HandleRng()
|