Searched refs:device_groups (Results 1 – 5 of 5) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | spmd_partitioner_util.cc | 1398 std::vector<std::vector<int64>> device_groups(Product(group_dim_sizes)); in GroupShardingOnDims() local 1407 device_groups[group_id].push_back(device); in GroupShardingOnDims() 1410 std::move(device_groups), in GroupShardingOnDims() 1441 if (grouped_sharding.device_groups[0].size() != 1) { in UngroupSharding() 1443 tiling_dims.push_back(grouped_sharding.device_groups[0].size()); in UngroupSharding() 1465 for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) { in UngroupSharding() 1475 tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device]; in UngroupSharding() 1501 CHECK_EQ(grouped_sharding.device_groups.size(), in AlignGroupsWith() 1502 reference.device_groups.size()); in AlignGroupsWith() 1504 for (int64 g = 0; g < reference.device_groups.size(); ++g) { in AlignGroupsWith() [all …]
|
D | spmd_partitioner_util.h | 298 GroupedSharding(std::vector<std::vector<int64>> device_groups, in GroupedSharding() 302 : device_groups(std::move(device_groups)), in GroupedSharding() 307 std::vector<std::vector<int64>> device_groups; member 348 const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b); 354 const std::vector<std::vector<int64>>& device_groups, 400 const std::vector<std::vector<int64>>& device_groups);
|
D | spmd_partitioner.cc | 461 hlo_, state_.partition_id, target_grouped.device_groups, group_dims, in ReshardNoCache() 959 partitioned_hlo.state(), sharding_grouped.device_groups, in ReshardToPartialReplicateWithAllGather() 1560 MakePartitioningState(), grouped.device_groups, &b_); in HandleConcatenate() 2773 inputs[0].state(), grouped.device_groups, &b_); in HandleReduce() 3088 MakePartitioningState(), sharding_grouped.device_groups, &b_); in HandleRng() 3433 std::vector<ReplicaGroup> device_groups; in GetDefaultCollectiveOpsCreator() local 3434 device_groups.reserve(partition_subgroups.size() * num_replicas); in GetDefaultCollectiveOpsCreator() 3437 device_groups.emplace_back(); in GetDefaultCollectiveOpsCreator() 3439 device_groups.back().add_replica_ids(i * num_partitions + pid); in GetDefaultCollectiveOpsCreator() 3444 operand->shape(), {operand}, reduction, device_groups, in GetDefaultCollectiveOpsCreator() [all …]
|
D | dot_handler.cc | 1728 lhs.state(), lhs_grouped.device_groups, b); in PartitionDotGroupOnBatch() 1743 lhs.state(), output_grouped.device_groups, b); in PartitionDotGroupOnBatch() 1756 output_grouped.device_groups, batch_dims, in PartitionDotGroupOnBatch() 1871 num_partitions / output_grouped.device_groups.size(), in PartitionDotGroupOnBatch() 1943 other_sharding, output_grouped.device_groups)) { in GetNonContractingPartitionGroupedShardingForOtherOperand() 2034 matching.state(), matching_grouped.device_groups, b); in PartitionDotGroupOnNonContracting() 2069 num_partitions / matching_grouped.device_groups.size(), in PartitionDotGroupOnNonContracting() 2192 output_sharding, lhs_grouped.device_groups)) { in PartitionDotGroupOnContracting() 2221 lhs.state(), lhs_grouped.device_groups, b); in PartitionDotGroupOnContracting() 2764 lhs.state(), grouped_output.device_groups, b); in PartitionDot()
|
D | gather_scatter_handler.cc | 418 operand.state(), grouped_operand.device_groups, b); in PartitionIndexParallelDimensions() 585 indices.state(), sharding_grouped.device_groups, &b_); in HandleScatter()
|