Home
last modified time | relevance | path

Searched refs:partition_subgroups (Results 1 – 4 of 4) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dfft_handler.cc164 std::vector<int64> partition_subgroups(num_partitions); in ShuffleDataWithAllToAll() local
165 std::iota(partition_subgroups.begin(), partition_subgroups.end(), 0); in ShuffleDataWithAllToAll()
166 groups[0] = partition_subgroups; in ShuffleDataWithAllToAll()
Dspmd_partitioner_util.cc1600 const std::vector<std::vector<int64>>& partition_subgroups) { in GetPerGroupCollectiveOpsCreator() argument
1601 if (partition_subgroups.empty()) { in GetPerGroupCollectiveOpsCreator()
1604 std::vector<std::vector<int64>> result(partition_subgroups.size() * in GetPerGroupCollectiveOpsCreator()
1607 for (int64 i = 0; i < partition_subgroups.size(); ++i) { in GetPerGroupCollectiveOpsCreator()
1608 result[g * partition_subgroups.size() + i].resize( in GetPerGroupCollectiveOpsCreator()
1609 partition_subgroups[i].size()); in GetPerGroupCollectiveOpsCreator()
1610 for (int64 j = 0; j < partition_subgroups[i].size(); ++j) { in GetPerGroupCollectiveOpsCreator()
1611 result[g * partition_subgroups.size() + i][j] = in GetPerGroupCollectiveOpsCreator()
1612 device_groups[g][partition_subgroups[i][j]]; in GetPerGroupCollectiveOpsCreator()
1621 const std::vector<std::vector<int64>>& partition_subgroups, in GetPerGroupCollectiveOpsCreator()
[all …]
Dspmd_partitioner.h120 const std::vector<std::vector<int64>>& partition_subgroups,
134 const std::vector<std::vector<int64>>& partition_subgroups,
142 const std::vector<std::vector<int64>>& partition_subgroups,
Dspmd_partitioner.cc3419 const std::vector<std::vector<int64>>& partition_subgroups, in GetDefaultCollectiveOpsCreator()
3421 if (partition_subgroups.size() <= 1) { in GetDefaultCollectiveOpsCreator()
3434 device_groups.reserve(partition_subgroups.size() * num_replicas); in GetDefaultCollectiveOpsCreator()
3436 for (const auto& pgroup : partition_subgroups) { in GetDefaultCollectiveOpsCreator()
3455 const std::vector<std::vector<int64>>& partition_subgroups, in GetDefaultCollectiveOpsCreator()
3461 std::vector<ReplicaGroup> groups(partition_subgroups.size()); in GetDefaultCollectiveOpsCreator()
3463 for (int64 id : partition_subgroups[i]) { in GetDefaultCollectiveOpsCreator()
3473 const std::vector<std::vector<int64>>& partition_subgroups, in GetDefaultCollectiveOpsCreator()
3476 device_groups.reserve(partition_subgroups.size() * num_replicas); in GetDefaultCollectiveOpsCreator()
3478 for (const auto& pgroup : partition_subgroups) { in GetDefaultCollectiveOpsCreator()
[all …]