Searched refs:partition_subgroups (Results 1 – 4 of 4) sorted by relevance
164 std::vector<int64> partition_subgroups(num_partitions); in ShuffleDataWithAllToAll() local165 std::iota(partition_subgroups.begin(), partition_subgroups.end(), 0); in ShuffleDataWithAllToAll()166 groups[0] = partition_subgroups; in ShuffleDataWithAllToAll()
1600 const std::vector<std::vector<int64>>& partition_subgroups) { in GetPerGroupCollectiveOpsCreator() argument1601 if (partition_subgroups.empty()) { in GetPerGroupCollectiveOpsCreator()1604 std::vector<std::vector<int64>> result(partition_subgroups.size() * in GetPerGroupCollectiveOpsCreator()1607 for (int64 i = 0; i < partition_subgroups.size(); ++i) { in GetPerGroupCollectiveOpsCreator()1608 result[g * partition_subgroups.size() + i].resize( in GetPerGroupCollectiveOpsCreator()1609 partition_subgroups[i].size()); in GetPerGroupCollectiveOpsCreator()1610 for (int64 j = 0; j < partition_subgroups[i].size(); ++j) { in GetPerGroupCollectiveOpsCreator()1611 result[g * partition_subgroups.size() + i][j] = in GetPerGroupCollectiveOpsCreator()1612 device_groups[g][partition_subgroups[i][j]]; in GetPerGroupCollectiveOpsCreator()1621 const std::vector<std::vector<int64>>& partition_subgroups, in GetPerGroupCollectiveOpsCreator()[all …]
120 const std::vector<std::vector<int64>>& partition_subgroups,134 const std::vector<std::vector<int64>>& partition_subgroups,142 const std::vector<std::vector<int64>>& partition_subgroups,
3419 const std::vector<std::vector<int64>>& partition_subgroups, in GetDefaultCollectiveOpsCreator()3421 if (partition_subgroups.size() <= 1) { in GetDefaultCollectiveOpsCreator()3434 device_groups.reserve(partition_subgroups.size() * num_replicas); in GetDefaultCollectiveOpsCreator()3436 for (const auto& pgroup : partition_subgroups) { in GetDefaultCollectiveOpsCreator()3455 const std::vector<std::vector<int64>>& partition_subgroups, in GetDefaultCollectiveOpsCreator()3461 std::vector<ReplicaGroup> groups(partition_subgroups.size()); in GetDefaultCollectiveOpsCreator()3463 for (int64 id : partition_subgroups[i]) { in GetDefaultCollectiveOpsCreator()3473 const std::vector<std::vector<int64>>& partition_subgroups, in GetDefaultCollectiveOpsCreator()3476 device_groups.reserve(partition_subgroups.size() * num_replicas); in GetDefaultCollectiveOpsCreator()3478 for (const auto& pgroup : partition_subgroups) { in GetDefaultCollectiveOpsCreator()[all …]