Searched refs:num_subdivs (Results 1 – 5 of 5) sorted by relevance
/external/tensorflow/tensorflow/core/common_runtime/ |
D | ring_alg.cc | 124 int num_subdivs = 0; in GenerateSubdivsInCollectiveParams() local 129 ++num_subdivs; in GenerateSubdivsInCollectiveParams() 130 int num_chunks = col_params->group.group_size * num_subdivs; in GenerateSubdivsInCollectiveParams() 132 VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks in GenerateSubdivsInCollectiveParams() 134 } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs); in GenerateSubdivsInCollectiveParams() 135 if (num_subdivs <= 0) { in GenerateSubdivsInCollectiveParams() 136 return errors::Internal("Unexpected num_subdivs ", num_subdivs, " in ", in GenerateSubdivsInCollectiveParams() 140 int subdiv_stride = kAvgDevPerTask / num_subdivs; in GenerateSubdivsInCollectiveParams() 142 col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs); in GenerateSubdivsInCollectiveParams() 143 for (int sdi = 0; sdi < num_subdivs; ++sdi) { in GenerateSubdivsInCollectiveParams() [all …]
|
D | hierarchical_tree_broadcaster.cc | 114 int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0); in InitializeCollectiveParams() local 118 col_params->instance.impl_details.subdiv_permutations.resize(num_subdivs); in InitializeCollectiveParams() 119 col_params->subdiv_rank.reserve(num_subdivs); in InitializeCollectiveParams() 120 col_params->instance.impl_details.subdiv_source_rank.reserve(num_subdivs); in InitializeCollectiveParams() 179 for (int sri = 0; sri < num_subdivs; sri++) { in InitializeCollectiveParams() 291 int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size()); in RunTree() local 297 for (int si = 0; si < num_subdivs; si++) { in RunTree() 359 if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) { in RunTree()
|
D | ring_reducer_test.cc | 156 const DeviceType& device_type, int num_subdivs, int fail_after) { in Init() argument 208 col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs); in Init() 209 col_params_.subdiv_rank.resize(num_subdivs); in Init() 210 int subdiv_stride = num_devices / num_subdivs; in Init() 211 for (int sdi = 0; sdi < num_subdivs; ++sdi) { in Init() 247 for (int sdi = 0; sdi < num_subdivs; ++sdi) { in Init() 285 int num_devices, int num_subdivs, int tensor_len, in RunTest() argument 287 Init(num_workers, num_devices, dtype, device_type, num_subdivs, fail_after); in RunTest() 419 int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size()); in DeviceInstance() local 428 for (int sdi = 0; sdi < num_subdivs; ++sdi) { in DeviceInstance() [all …]
|
D | ring_gatherer_test.cc | 134 const DeviceType& device_type, int num_subdivs, int fail_after) { in Init() argument 186 col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs); in Init() 187 col_params_.subdiv_rank.resize(num_subdivs); in Init() 188 int subdiv_stride = num_devices / num_subdivs; in Init() 189 for (int sdi = 0; sdi < num_subdivs; ++sdi) { in Init() 225 for (int sdi = 0; sdi < num_subdivs; ++sdi) { in Init() 263 int num_devices, int num_subdivs, int tensor_len, in RunTest() argument 265 Init(num_workers, num_devices, dtype, device_type, num_subdivs, fail_after); in RunTest() 395 int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size()); in DeviceInstance() local 404 for (int sdi = 0; sdi < num_subdivs; ++sdi) { in DeviceInstance()
|
D | hierarchical_tree_broadcaster_test.cc | 266 int num_subdivs = num_workers + (num_workers > 1 ? 1 : 0); in Init() local 267 VLOG(2) << "#subdiv=" << num_subdivs; in Init() 268 col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs); in Init() 269 col_params_.subdiv_rank.resize(num_subdivs); in Init() 293 for (int i = 0; subdiv_i < num_subdivs; i++, subdiv_i++) { in Init() 583 size_t num_subdivs = impl.subdiv_permutations.size(); in DeviceInstance() local 584 impl.subdiv_source_rank.resize(num_subdivs, 0); in DeviceInstance() 585 col_params_.subdiv_rank.resize(num_subdivs); in DeviceInstance() 586 for (size_t si = 0; si < num_subdivs; si++) { in DeviceInstance()
|