
Searched refs: device_groups (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
spmd_partitioner_util.cc
1398 std::vector<std::vector<int64>> device_groups(Product(group_dim_sizes)); in GroupShardingOnDims() local
1407 device_groups[group_id].push_back(device); in GroupShardingOnDims()
1410 std::move(device_groups), in GroupShardingOnDims()
1441 if (grouped_sharding.device_groups[0].size() != 1) { in UngroupSharding()
1443 tiling_dims.push_back(grouped_sharding.device_groups[0].size()); in UngroupSharding()
1465 for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) { in UngroupSharding()
1475 tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device]; in UngroupSharding()
1501 CHECK_EQ(grouped_sharding.device_groups.size(), in AlignGroupsWith()
1502 reference.device_groups.size()); in AlignGroupsWith()
1504 for (int64 g = 0; g < reference.device_groups.size(); ++g) { in AlignGroupsWith()
[all …]
spmd_partitioner_util.h
298 GroupedSharding(std::vector<std::vector<int64>> device_groups, in GroupedSharding()
302 : device_groups(std::move(device_groups)), in GroupedSharding()
307 std::vector<std::vector<int64>> device_groups; member
348 const std::vector<std::vector<int64>>& device_groups, SpmdBuilder* b);
354 const std::vector<std::vector<int64>>& device_groups,
400 const std::vector<std::vector<int64>>& device_groups);
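The spmd_partitioner_util.h hits above show that device_groups is a vector of device-id vectors held by GroupedSharding, one inner vector per sharding group. Below is a minimal standalone sketch of that shape only; the GroupedShardingLike struct and the 8-device grouping are hypothetical stand-ins for illustration, not the XLA types, and int64_t replaces TF's int64 alias.

// Standalone sketch (not the XLA API): mirrors just the device_groups member
// visible at spmd_partitioner_util.h:307.
#include <cstdint>
#include <iostream>
#include <vector>

struct GroupedShardingLike {
  // One inner vector per sharding group; each entry is a global device id.
  std::vector<std::vector<int64_t>> device_groups;
};

int main() {
  // Hypothetical example: 8 devices split into 4 groups of 2.
  GroupedShardingLike grouped;
  grouped.device_groups = {{0, 1}, {2, 3}, {4, 5}, {6, 7}};
  for (size_t g = 0; g < grouped.device_groups.size(); ++g) {
    std::cout << "group " << g << ":";
    for (int64_t device : grouped.device_groups[g]) {
      std::cout << ' ' << device;
    }
    std::cout << '\n';
  }
  return 0;
}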
spmd_partitioner.cc
461 hlo_, state_.partition_id, target_grouped.device_groups, group_dims, in ReshardNoCache()
959 partitioned_hlo.state(), sharding_grouped.device_groups, in ReshardToPartialReplicateWithAllGather()
1560 MakePartitioningState(), grouped.device_groups, &b_); in HandleConcatenate()
2773 inputs[0].state(), grouped.device_groups, &b_); in HandleReduce()
3088 MakePartitioningState(), sharding_grouped.device_groups, &b_); in HandleRng()
3433 std::vector<ReplicaGroup> device_groups; in GetDefaultCollectiveOpsCreator() local
3434 device_groups.reserve(partition_subgroups.size() * num_replicas); in GetDefaultCollectiveOpsCreator()
3437 device_groups.emplace_back(); in GetDefaultCollectiveOpsCreator()
3439 device_groups.back().add_replica_ids(i * num_partitions + pid); in GetDefaultCollectiveOpsCreator()
3444 operand->shape(), {operand}, reduction, device_groups, in GetDefaultCollectiveOpsCreator()
[all …]
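The spmd_partitioner.cc hits at 3433-3439 build one device group per (replica, partition subgroup) pair, flattening ids as i * num_partitions + pid. The sketch below reproduces only that indexing with plain vectors instead of xla::ReplicaGroup; the subgroup inputs, counts, and loop nesting are assumptions made for illustration, since the surrounding lines are not shown.

// Standalone sketch (no XLA dependency) of the flattened-id pattern
// i * num_partitions + pid seen in GetDefaultCollectiveOpsCreator().
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t num_replicas = 2;
  const int64_t num_partitions = 4;
  // Hypothetical input: two partition subgroups of two partitions each.
  const std::vector<std::vector<int64_t>> partition_subgroups = {{0, 1}, {2, 3}};

  std::vector<std::vector<int64_t>> device_groups;
  device_groups.reserve(partition_subgroups.size() * num_replicas);
  for (int64_t i = 0; i < num_replicas; ++i) {
    for (const auto& pids : partition_subgroups) {
      device_groups.emplace_back();
      for (int64_t pid : pids) {
        // Flattened global id: replica-major layout over (replica, partition).
        device_groups.back().push_back(i * num_partitions + pid);
      }
    }
  }

  for (const auto& group : device_groups) {
    for (int64_t id : group) std::cout << id << ' ';
    std::cout << '\n';
  }
  return 0;
}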
dot_handler.cc
1728 lhs.state(), lhs_grouped.device_groups, b); in PartitionDotGroupOnBatch()
1743 lhs.state(), output_grouped.device_groups, b); in PartitionDotGroupOnBatch()
1756 output_grouped.device_groups, batch_dims, in PartitionDotGroupOnBatch()
1871 num_partitions / output_grouped.device_groups.size(), in PartitionDotGroupOnBatch()
1943 other_sharding, output_grouped.device_groups)) { in GetNonContractingPartitionGroupedShardingForOtherOperand()
2034 matching.state(), matching_grouped.device_groups, b); in PartitionDotGroupOnNonContracting()
2069 num_partitions / matching_grouped.device_groups.size(), in PartitionDotGroupOnNonContracting()
2192 output_sharding, lhs_grouped.device_groups)) { in PartitionDotGroupOnContracting()
2221 lhs.state(), lhs_grouped.device_groups, b); in PartitionDotGroupOnContracting()
2764 lhs.state(), grouped_output.device_groups, b); in PartitionDot()
gather_scatter_handler.cc
418 operand.state(), grouped_operand.device_groups, b); in PartitionIndexParallelDimensions()
585 indices.state(), sharding_grouped.device_groups, &b_); in HandleScatter()