/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <algorithm>
#include <random>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/nccl/nccl_manager.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"

namespace tensorflow {

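// Collects all GPU devices registered with the local DeviceFactory; the tests
// below require at least two of them.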
static std::vector<std::unique_ptr<BaseGPUDevice>> GetGPUDevices() {
  std::vector<std::unique_ptr<Device>> devices;
  TF_CHECK_OK(DeviceFactory::GetFactory(DEVICE_GPU)
                  ->AddDevices(SessionOptions(), "", &devices));
  std::vector<std::unique_ptr<BaseGPUDevice>> gpus;
  for (std::unique_ptr<Device>& device : devices) {
    if (device->device_type() == "GPU") {
      // If `device_type()` is GPU, this `Device` is guaranteed to be a
      // `BaseGPUDevice`, which is a subclass of `Device`.
      gpus.emplace_back(static_cast<BaseGPUDevice*>(device.release()));
    }
  }
  return gpus;
}

template <typename Scalar>
class NcclManagerTest : public ::testing::Test {
 public:
  // A single collective (reduction, gather, or broadcast) to apply.
  struct TestCase {
    TestCase(int num_nodes, int num_ranks_per_node)
        : num_nodes(num_nodes), num_ranks_per_node(num_ranks_per_node) {}
    std::vector<Tensor> ins;
    std::vector<Tensor> outs;
    Tensor expected;
    const int num_nodes;
    const int num_ranks_per_node;

    mutex mu;
    Status final_status;
    int num_completed TF_GUARDED_BY(mu) = 0;
    condition_variable done_cv;
  };

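  // One-time environment setup shared by every test in the suite: enables
  // NCCL debug logging, gathers the local GPUs, and creates the work queue
  // used to run per-rank work concurrently.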
  static void SetUpTestSuite() {
    setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
    setenv("NCCL_LAUNCH_MODE", "PARALLEL", 1 /* replace */);
    devices_ = new std::vector<std::unique_ptr<BaseGPUDevice>>(GetGPUDevices());
    VLOG(1) << "Running test with " << devices_->size() << " gpus";
    if (devices_->size() <= 1) {
      LOG(FATAL) << "Cannot run NCCL test without multiple GPUs";
    }
    work_queue_ = new UnboundedWorkQueue(Env::Default(), "nccl_manager_test");
  }

  void SetUp() override {
    ASSERT_GT(devices_->size(), 0) << "No GPUs found";
    ASSERT_NE(work_queue_, nullptr);
  }

  static int32 NumGPUs() { return static_cast<int32>(devices_->size()); }

  // Let N = #GPUs.  When N is even, num_nodes=2 and num_ranks_per_node=N/2.
  // When N is odd, num_nodes=2 and num_ranks_per_node=(N-1)/2.
  static void PopulateMultiNodeParams(int* num_nodes,
                                      int* num_ranks_per_node) {
    const auto num_gpus = NumGPUs();
    CHECK_GT(num_gpus, 1);
    *num_nodes = 2;
    if (num_gpus % 2 == 0) {
      *num_ranks_per_node = num_gpus / 2;
    } else {
      *num_ranks_per_node = (num_gpus - 1) / 2;
    }
  }

  static void TearDownTestSuite() {
    delete devices_;
    delete work_queue_;
  }

  TestCase* MakeReductionTestCase(int num_nodes, int num_ranks_per_node,
                                  ncclRedOp_t reduction_op, TensorShape shape,
                                  float value_offset) {
    TestCase* test_case = new TestCase(num_nodes, num_ranks_per_node);
    test_case->expected = Tensor(data_type_, shape);
    if (reduction_op == ncclProd) {
      test::FillFn<Scalar>(&test_case->expected,
                           [](int) { return static_cast<Scalar>(1); });
    } else if (reduction_op == ncclSum) {
      test::FillFn<Scalar>(&test_case->expected,
                           [](int) { return static_cast<Scalar>(0); });
    } else if (reduction_op == ncclMax) {
      test::FillFn<Scalar>(&test_case->expected, [](int) { return -max_; });
    } else if (reduction_op == ncclMin) {
      test::FillFn<Scalar>(&test_case->expected, [](int) { return max_; });
    } else {
      LOG(FATAL) << "Invalid reduction_op " << reduction_op;
    }

    float value_scale = 0.01;  // Small scale to avoid fp16 overflow.
    for (int node = 0; node < num_nodes; ++node) {
      for (int local_rank = 0; local_rank < num_ranks_per_node; ++local_rank) {
        auto* device = GetDevice(num_ranks_per_node, node, local_rank);
        auto* stream = device->tensorflow_gpu_device_info()->stream;

        Tensor in_cpu(data_type_, shape);
        test::FillFn<Scalar>(&in_cpu, [&](int index) {
          return static_cast<Scalar>((index + 1) * value_scale + value_offset);
        });
        for (int j = 0; j < shape.num_elements(); ++j) {
          auto in_val = in_cpu.flat<Scalar>()(j);
          auto out_expr = test_case->expected.template flat<Scalar>();
          if (reduction_op == ncclProd) {
            out_expr(j) = out_expr(j) * in_val;
          } else if (reduction_op == ncclSum) {
            out_expr(j) = out_expr(j) + in_val;
          } else if (reduction_op == ncclMax) {
            if (in_val > out_expr(j)) {
              out_expr(j) = in_val;
            }
          } else if (reduction_op == ncclMin) {
            if (in_val < out_expr(j)) {
              out_expr(j) = in_val;
            }
          }
        }

        value_scale *= 10;
        test_case->ins.emplace_back(GpuAllocator(device), data_type_, shape);
        test_case->outs.emplace_back(GpuAllocator(device), data_type_, shape);

        const Tensor& in_gpu = test_case->ins.back();
        auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<Scalar>().data());
        stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<Scalar>().data(),
                           in_cpu.TotalBytes());
      }
    }

    return test_case;
  }

  TestCase* MakeGatherTestCase(int num_nodes, int num_ranks_per_node,
                               TensorShape in_shape, TensorShape out_shape) {
    TestCase* test_case = new TestCase(num_nodes, num_ranks_per_node);
    test_case->expected = Tensor(data_type_, out_shape);
    test::FillFn<Scalar>(&test_case->expected,
                         [](int) { return static_cast<Scalar>(0); });

    float value_scale = 0.01;  // Small scale to avoid fp16 overflow.
    for (int node = 0; node < num_nodes; ++node) {
      for (int i = 0; i < num_ranks_per_node; ++i) {
        auto* device = GetDevice(num_ranks_per_node, node, i);
        auto* stream = device->tensorflow_gpu_device_info()->stream;

        Tensor in_cpu(data_type_, in_shape);
        test::FillFn<Scalar>(&in_cpu, [&](int index) {
          return static_cast<Scalar>((index + 1) * value_scale);
        });
        // Starting index for this rank's tensor in the all-gathered output.
        int32 gather_idx =
            (node * num_ranks_per_node + i) * in_shape.num_elements();
        for (int j = 0; j < in_shape.num_elements(); ++j) {
          auto in_val = in_cpu.flat<Scalar>()(j);
          auto out_expr = test_case->expected.template flat<Scalar>();
          out_expr(gather_idx + j) = in_val;
        }

        value_scale *= 10;
        test_case->ins.emplace_back(GpuAllocator(device), data_type_, in_shape);
        test_case->outs.emplace_back(GpuAllocator(device), data_type_,
                                     out_shape);

        const Tensor& in_gpu = test_case->ins.back();
        auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<Scalar>().data());
        stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<Scalar>().data(),
                           in_cpu.TotalBytes());
      }
    }

    return test_case;
  }

  // Makes a broadcast test case which broadcasts a tensor with shape `shape`
  // from (`src_node`, `src_rank`) to all other ranks.
  // If `in_place` is true, the source's input and output are the same tensor;
  // otherwise they are tensors backed by different buffers.
  TestCase* MakeBroadcastTestCase(int num_nodes, int num_ranks_per_node,
                                  TensorShape shape, int src_node, int src_rank,
                                  bool in_place) {
    TestCase* test_case = new TestCase(num_nodes, num_ranks_per_node);
    test_case->expected = Tensor(data_type_, shape);
    test::FillFn<Scalar>(&test_case->expected,
                         [](int) { return static_cast<Scalar>(1); });

    for (int node = 0; node < num_nodes; ++node) {
      for (int local_rank = 0; local_rank < num_ranks_per_node; ++local_rank) {
        auto* device = GetDevice(num_ranks_per_node, node, local_rank);
        if (node == src_node && local_rank == src_rank) {
          test_case->ins.emplace_back(GpuAllocator(device), data_type_, shape);
          if (in_place) {
            test_case->outs.emplace_back(test_case->ins.back());
          } else {
            test_case->outs.emplace_back(GpuAllocator(device), data_type_,
                                         shape);
          }
          Tensor in_cpu(data_type_, shape);
          test::FillFn<Scalar>(&in_cpu,
                               [](int) { return static_cast<Scalar>(1); });
          const Tensor& in_gpu = test_case->ins.back();
          auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<Scalar>().data());
          auto* stream = device->tensorflow_gpu_device_info()->stream;
          stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<Scalar>().data(),
                             in_cpu.TotalBytes());
        } else {
          test_case->ins.emplace_back(Tensor());
          test_case->outs.emplace_back(GpuAllocator(device), data_type_, shape);
        }
      }
    }

    return test_case;
  }

  // Waits for the done callback to be called for each participant.
  void WaitForTestCompletion(TestCase* test_case) {
    mutex_lock l(test_case->mu);
    while (test_case->num_completed != test_case->outs.size()) {
      test_case->done_cv.wait(l);
    }
  }

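  // Blocks until every participant has completed, then copies each output
  // tensor back to the host and compares it against `expected`.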
  void VerifyResults(TestCase* test_case) {
    WaitForTestCompletion(test_case);
    TF_ASSERT_OK(test_case->final_status);
    // Copy memory to host and verify.
    for (int node = 0; node < test_case->num_nodes; ++node) {
      for (int local_rank = 0; local_rank < test_case->num_ranks_per_node;
           ++local_rank) {
        auto* device =
            GetDevice(test_case->num_ranks_per_node, node, local_rank);
        auto* stream = device->tensorflow_gpu_device_info()->stream;
        const int global_rank =
            GlobalRank(test_case->num_ranks_per_node, node, local_rank);
        const Tensor& out_gpu = test_case->outs[global_rank];
        Tensor out_cpu(data_type_, out_gpu.shape());
        auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<Scalar>().data());
        stream->ThenMemcpy(out_cpu.flat<Scalar>().data(), out_gpu_mem,
                           out_cpu.TotalBytes());
        SE_ASSERT_OK(stream->BlockHostUntilDone());
        VLOG(1) << "Verifying rank " << global_rank << " expected shape "
                << test_case->expected.shape() << " out shape "
                << out_cpu.shape();
        test::ExpectClose(test_case->expected, out_cpu);
      }
    }
  }

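  // Blocks until every participant has completed and expects the collective
  // to have failed with an INTERNAL error.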
  void VerifyError(TestCase* test_case) {
    WaitForTestCompletion(test_case);
    LOG(INFO) << test_case->final_status;
    EXPECT_EQ(test_case->final_status.code(), error::INTERNAL);
  }

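  // Returns a done callback that records the participant's status under the
  // test case's mutex and signals `done_cv` once all participants are done.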
  NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) {
    return [this, test_case](Status s) {
      mutex_lock l(test_case->mu);
      test_case->final_status.Update(s);
      if (++test_case->num_completed == test_case->outs.size()) {
        test_case->done_cv.notify_one();
      }
    };
  }

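  // Per-node state for the simulated multi-node tests: each node owns its own
  // NcclManager and counts how many of its local ranks have been launched.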
  struct NodeState {
    NcclManager nccl_manager;
    std::atomic<int> launched{0};
  };

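  // Runs an all-reduce across `num_nodes` simulated nodes with
  // `num_ranks_per_node` GPUs per node, each node using its own NcclManager.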
  void RunMultiNodeAllReduceTest(const int num_nodes,
                                 const int num_ranks_per_node) {
    std::vector<NodeState> node_states(num_nodes);
    RunMultiNodeAllReduceTest(node_states, num_ranks_per_node);
  }

  void RunMultiNodeAllReduceTest(std::vector<NodeState>& node_states,
                                 const int num_ranks_per_node) {
    const int num_nodes = node_states.size();
    const int num_global_ranks = num_nodes * num_ranks_per_node;
    const string collective_key = "allreduce";
    // The NcclManagers in this test synchronize in real-time, so we need to run
    // each node's code in a separate thread.
    // Specifically, the call to ncclGroupEnd() after calling ncclCommInitRank
    // waits for all communicators before returning.

    // First, initialize the communicator_key used for this collective.
    const string communicator_key =
        node_states[0].nccl_manager.GenerateCommunicatorKey();

    for (int op = 0; op < 4; ++op) {
      ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(op);
      std::unique_ptr<TestCase> test_case(
          this->MakeReductionTestCase(num_nodes, num_ranks_per_node,
                                      reduction_op, TensorShape({2, 3}), 0.0f));
      for (int node = 0; node < num_nodes; ++node) {
        auto node_fn = [this, node, num_ranks_per_node, num_global_ranks,
                        &node_states, &communicator_key, &collective_key,
                        reduction_op, &test_case] {
          for (int local_rank = 0; local_rank < num_ranks_per_node;
               ++local_rank) {
            auto* device = GetDevice(num_ranks_per_node, node, local_rank);
            auto* info = device->tensorflow_gpu_device_info();
            auto* stream = device->tensorflow_gpu_device_info()->stream;
            const int global_rank =
                GlobalRank(num_ranks_per_node, node, local_rank);
            auto participant = absl::make_unique<NcclManager::Participant>(
                device->executor(), stream, info, &test_case->ins[global_rank],
                &test_case->outs[global_rank], global_rank,
                this->CreateDoneCallback(test_case.get()));
            node_states[node].nccl_manager.AddToAllReduce(
                std::move(participant),
                {collective_key, num_ranks_per_node, num_global_ranks,
                 communicator_key, /*source_rank=*/-1},
                reduction_op);
            VLOG(1) << "AddToAllReduce node " << node << " global_rank "
                    << global_rank;
          }

          // Signal collective ready to launch at this node.
          node_states[node].nccl_manager.SignalMultiNodeReady(collective_key);
        };
        this->work_queue_->Schedule(node_fn);
      }

      VLOG(2) << "Verifying results";
      this->VerifyResults(test_case.get());
    }
  }

  void RunMultiNodeBroadcastTest(const int num_nodes,
                                 const int num_ranks_per_node,
                                 const int src_node, const int src_local_rank,
                                 const bool in_place) {
    const int num_global_ranks = num_nodes * num_ranks_per_node;
    const int src_global_rank = src_node * num_ranks_per_node + src_local_rank;
    const string collective_key = "broadcast";
    std::vector<NodeState> node_states(num_nodes);
    const string communicator_key =
        node_states[0].nccl_manager.GenerateCommunicatorKey();
    std::unique_ptr<TestCase> test_case(this->MakeBroadcastTestCase(
        num_nodes, num_ranks_per_node, TensorShape({5, 6}), src_node,
        src_local_rank, in_place));
    for (int node = 0; node < num_nodes; ++node) {
      for (int local_rank = 0; local_rank < num_ranks_per_node; ++local_rank) {
        // Launch each rank in a separate thread to test concurrent,
        // randomly-ordered calls into NcclManager.
        auto rank_fn = [this, node, num_ranks_per_node, num_global_ranks,
                        src_global_rank, local_rank, &node_states,
                        &collective_key, &communicator_key, &test_case]() {
          auto* device = GetDevice(num_ranks_per_node, node, local_rank);
          auto* info = device->tensorflow_gpu_device_info();
          auto* stream = device->tensorflow_gpu_device_info()->stream;
          const int global_rank =
              GlobalRank(num_ranks_per_node, node, local_rank);
          auto* input = global_rank == src_global_rank
                            ? &test_case->ins[global_rank]
                            : nullptr;
          auto* output = test_case->outs[global_rank].NumElements() == 0
                             ? nullptr
                             : &test_case->outs[global_rank];
          auto participant = absl::make_unique<NcclManager::Participant>(
              device->executor(), stream, info, input, output, global_rank,
              this->CreateDoneCallback(test_case.get()));
          if (global_rank == src_global_rank) {
            node_states[node].nccl_manager.AddBroadcastSend(
                std::move(participant),
                {collective_key, num_ranks_per_node, num_global_ranks,
                 communicator_key, src_global_rank});
          } else {
            node_states[node].nccl_manager.AddBroadcastRecv(
                std::move(participant),
                {collective_key, num_ranks_per_node, num_global_ranks,
                 communicator_key, src_global_rank});
          }

          if (++node_states[node].launched == num_ranks_per_node) {
            // Signal collective ready to launch at this node.
            node_states[node].nccl_manager.SignalMultiNodeReady(collective_key);
          }
        };
        this->work_queue_->Schedule(std::move(rank_fn));
      }
    }

    VLOG(2) << "Verifying results";
    this->VerifyResults(test_case.get());
  }

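  // Maps (node, local_rank) to the flat global rank, assuming ranks are laid
  // out node-major.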
  static int GlobalRank(int num_ranks_per_node, int node, int local_rank) {
    return node * num_ranks_per_node + local_rank;
  }

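  // Returns the GPU that plays the role of (node, local_rank) in the
  // simulated topology.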
  static BaseGPUDevice* GetDevice(int num_ranks_per_node, int node,
                                  int local_rank) {
    const int device_idx = GlobalRank(num_ranks_per_node, node, local_rank);
    CHECK_LT(device_idx, devices_->size());
    return (*devices_)[device_idx].get();
  }

  static UnboundedWorkQueue* work_queue_;

 private:
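  // Returns the allocator used to allocate the test input and output tensors
  // directly on `device`.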
  static Allocator* GpuAllocator(BaseGPUDevice* device) {
    return device->GetAllocator(AllocatorAttributes());
  }

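  // Wraps a raw device pointer in a typed se::DeviceMemory so it can be
  // passed to Stream::ThenMemcpy.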
  static se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) {
    se::DeviceMemoryBase wrapped(const_cast<Scalar*>(cuda_memory));
    se::DeviceMemory<Scalar> typed(wrapped);
    return typed;
  }

  static std::vector<std::unique_ptr<BaseGPUDevice>>* devices_;
  static const DataType data_type_;
  static const Scalar max_;
};

template <typename Scalar>
std::vector<std::unique_ptr<BaseGPUDevice>>* NcclManagerTest<Scalar>::devices_ =
    nullptr;
template <typename Scalar>
const DataType NcclManagerTest<Scalar>::data_type_ =
    DataTypeToEnum<Scalar>::value;
template <typename Scalar>
const Scalar NcclManagerTest<Scalar>::max_ =
    Eigen::NumTraits<Scalar>::highest();
template <typename Scalar>
UnboundedWorkQueue* NcclManagerTest<Scalar>::work_queue_ = nullptr;

// Instantiate tests for float and double.
using TypeList = ::testing::Types<float, double>;
TYPED_TEST_SUITE(NcclManagerTest, TypeList);

// Test basic sum reduction.
TYPED_TEST(NcclManagerTest, BasicSumReduction) {
  const int num_ranks = this->NumGPUs();

  for (int op = 0; op < 4; ++op) {
    ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(op);
    std::unique_ptr<typename TestFixture::TestCase> test_case(
        this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, reduction_op,
                                    TensorShape({2, 3}), 0.0f));
    for (int rank = 0; rank < num_ranks; ++rank) {
      auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
      VLOG(2) << "rank " << rank << " device " << device->name();
      auto* info = device->tensorflow_gpu_device_info();
      auto* stream = device->tensorflow_gpu_device_info()->stream;
      auto participant = absl::make_unique<NcclManager::Participant>(
          device->executor(), stream, info, &test_case->ins[rank],
          &test_case->outs[rank], /*global_rank=*/-1,
          this->CreateDoneCallback(test_case.get()));
      NcclManager::instance()->AddToAllReduce(
          std::move(participant),
          {"allreduce", /*num_local_devices=*/num_ranks,
           /*num_global_devices=*/num_ranks, /*communicator_key=*/"",
           /*source_rank=*/-1},
          reduction_op);
    }

    LOG(INFO) << "Verifying results";
    this->VerifyResults(test_case.get());
  }
}

// Same as the Basic test, but with multiple threads launching parts of many
// reductions.
//
// To run the test longer, increase num_ranks, num_collectives_per_iteration,
// and time_limit_micros.
TYPED_TEST(NcclManagerTest, MultipleCallers) {
  const int num_ranks = this->NumGPUs();
  const int num_collectives_per_iteration = 10;
  const int time_limit_micros = 1 * 1000 * 1000;  // 1 second

  int64 start = Env::Default()->NowMicros();
  srand(Env::Default()->NowMicros());

  for (;;) {
    std::vector<std::pair<int, int>> case_and_rank;
    std::vector<std::unique_ptr<typename TestFixture::TestCase>> test_cases;
    for (int i = 0; i < num_collectives_per_iteration; ++i) {
      test_cases.emplace_back(this->MakeReductionTestCase(
          /*num_nodes=*/1, num_ranks, ncclSum,
          TensorShape({100, i % 5 + 1, i % 3 + 1}), 1.1f * i));
      for (int j = 0; j < num_ranks; ++j) {
        case_and_rank.emplace_back(i, j);
      }
    }

    for (int rank = 0; rank < num_ranks; ++rank) {
      auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
      auto* stream = device->tensorflow_gpu_device_info()->stream;
      SE_ASSERT_OK(stream->BlockHostUntilDone());
    }

    std::shuffle(case_and_rank.begin(), case_and_rank.end(),
                 std::mt19937(std::random_device()()));

    mutex mu;  // guards case_and_rank.
    const int to_schedule = case_and_rank.size();
    for (int i = 0; i < to_schedule; ++i) {
      auto fn = [&]() {
        int rank;
        int test_num;
        {
          mutex_lock l(mu);
          test_num = case_and_rank.back().first;
          rank = case_and_rank.back().second;
          case_and_rank.pop_back();
        }
        auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
        auto* info = device->tensorflow_gpu_device_info();
        auto* stream = device->tensorflow_gpu_device_info()->stream;
        typename TestFixture::TestCase* test_case = test_cases[test_num].get();
        auto participant = absl::make_unique<NcclManager::Participant>(
            device->executor(), stream, info, &test_case->ins[rank],
            &test_case->outs[rank], /*global_rank=*/-1,
            this->CreateDoneCallback(test_case));
        NcclManager::instance()->AddToAllReduce(
            std::move(participant),
            {strings::StrCat("allreduce", test_num),
             /*num_local_devices=*/num_ranks,
             /*num_global_devices=*/num_ranks,
             /*communicator_key=*/"", /*source_rank=*/-1},
            ncclSum);
      };
      this->work_queue_->Schedule(fn);
    }

    VLOG(2) << "Verifying results for " << num_collectives_per_iteration
            << " collectives";
    for (int i = 0; i < test_cases.size(); ++i) {
      this->VerifyResults(test_cases[i].get());
    }

    int64 delta = Env::Default()->NowMicros() - start;
    if (delta > time_limit_micros) {
      LOG(INFO) << "Ran for " << delta << " microsecs, now quitting";
      break;
    }
  }
}

// Test basic all-gather.
TYPED_TEST(NcclManagerTest, BasicAllGather) {
  const int num_ranks = this->NumGPUs();
  for (int i = 0; i < num_ranks; ++i) {
    std::unique_ptr<typename TestFixture::TestCase> test_case(
        this->MakeGatherTestCase(/*num_nodes=*/1, num_ranks,
                                 TensorShape({2, 3}),
                                 TensorShape({2 * num_ranks, 3})));
    for (int rank = 0; rank < num_ranks; ++rank) {
      auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
      VLOG(2) << "rank " << rank << " device " << device->name();
      auto* info = device->tensorflow_gpu_device_info();
      auto* stream = device->tensorflow_gpu_device_info()->stream;
      auto participant = absl::make_unique<NcclManager::Participant>(
          device->executor(), stream, info, &test_case->ins[rank],
          &test_case->outs[rank], rank,
          this->CreateDoneCallback(test_case.get()));
      NcclManager::instance()->AddToAllGather(
          std::move(participant),
          {"allgather", /*num_local_devices=*/num_ranks,
           /*num_global_devices=*/num_ranks, /*communicator_key=*/"",
           /*source_rank=*/-1});
    }

    LOG(INFO) << "Verifying results";
    this->VerifyResults(test_case.get());
  }
}

// Test basic broadcast.
TYPED_TEST(NcclManagerTest, BasicBroadcast) {
  this->RunMultiNodeBroadcastTest(/*num_nodes=*/1,
                                  /*num_ranks_per_node=*/this->NumGPUs(),
                                  /*src_node=*/0, /*src_local_rank=*/0,
                                  /*in_place=*/false);
}

// Test in-place broadcast.
TYPED_TEST(NcclManagerTest, InPlaceBroadcast) {
  this->RunMultiNodeBroadcastTest(/*num_nodes=*/1,
                                  /*num_ranks_per_node=*/this->NumGPUs(),
                                  /*src_node=*/0, /*src_local_rank=*/0,
                                  /*in_place=*/true);
}

// Test broadcast with increasing ranks.
TYPED_TEST(NcclManagerTest, BroadcastWithDifferentRanks) {
  for (int num_ranks = 1; num_ranks <= this->NumGPUs(); ++num_ranks) {
    const int src_rank = static_cast<int>(random::New64() % num_ranks);
    for (int in_place_idx = 0; in_place_idx <= 1; ++in_place_idx) {
      const bool in_place = in_place_idx == 0;
      this->RunMultiNodeBroadcastTest(/*num_nodes=*/1, num_ranks,
                                      /*src_node=*/0, src_rank, in_place);
    }
  }
}

// Multi-node NCCL tests.

TEST(NcclManagerTest, CommunicatorKey) {
  const string communicator_key =
      NcclManager::instance()->GenerateCommunicatorKey();
  EXPECT_EQ(communicator_key.size(), NCCL_UNIQUE_ID_BYTES);
}

#if !TENSORFLOW_USE_ROCM
// The ROCm platform currently does not support simulating a multi-node
// environment on a single node with multiple GPUs, so tests that rely on such
// simulation are skipped on ROCm.

// This test creates `num_nodes` NcclManagers to simulate a multi-node
// environment.  It works on a single node with multiple GPUs.  It enqueues
// NCCL kernels on a separate stream per rank.
TYPED_TEST(NcclManagerTest, MultiNode) {
  int num_nodes;
  int num_ranks_per_node;
  this->PopulateMultiNodeParams(&num_nodes, &num_ranks_per_node);
  VLOG(1) << "Calling RunMultiNodeAllReduceTest with num_nodes=" << num_nodes
          << " and num_ranks_per_node=" << num_ranks_per_node;
  this->RunMultiNodeAllReduceTest(num_nodes, num_ranks_per_node);
}
#endif

// Tests that specifying a `communicator_key` with a single-node NCCL
// collective works correctly.
TYPED_TEST(NcclManagerTest, MultiNodeSingle) {
  this->RunMultiNodeAllReduceTest(/*num_nodes=*/1,
                                  /*num_ranks_per_node=*/this->NumGPUs());
}

#if !TENSORFLOW_USE_ROCM
// The ROCm platform currently does not support simulating a multi-node
// environment on a single node with multiple GPUs, so tests that rely on such
// simulation are skipped on ROCm.

// Multi-node broadcast.
TYPED_TEST(NcclManagerTest, MultiNodeBroadcast) {
  int num_nodes;
  int num_ranks_per_node;
  this->PopulateMultiNodeParams(&num_nodes, &num_ranks_per_node);
  VLOG(1) << "Calling RunMultiNodeBroadcastTest with num_nodes=" << num_nodes
          << " and num_ranks_per_node=" << num_ranks_per_node;
  this->RunMultiNodeBroadcastTest(num_nodes, num_ranks_per_node,
                                  /*src_node=*/0, /*src_local_rank=*/0,
                                  /*in_place=*/true);
}
#endif

// Checks that we return an error status if a collective_key is used for
// different types of collectives, e.g. a reduction and a broadcast.
TYPED_TEST(NcclManagerTest, ConsistentCollectiveType) {
  const int num_ranks = 2;

  std::unique_ptr<typename TestFixture::TestCase> test_case(
      this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, ncclSum,
                                  TensorShape({2, 3}), 0.0f));
  for (int rank = 0; rank < num_ranks; ++rank) {
    auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
    auto* info = device->tensorflow_gpu_device_info();
    auto* stream = device->tensorflow_gpu_device_info()->stream;
    auto participant = absl::make_unique<NcclManager::Participant>(
        device->executor(), stream, info, &test_case->ins[rank],
        &test_case->outs[rank], /*global_rank=*/-1,
        this->CreateDoneCallback(test_case.get()));
    if (rank == 0) {
      NcclManager::instance()->AddToAllReduce(std::move(participant),
                                              {"bad_coll_type",
                                               /*num_local_devices=*/num_ranks,
                                               /*num_global_devices=*/num_ranks,
                                               /*communicator_key=*/"",
                                               /*source_rank=*/-1},
                                              ncclSum);
    } else {
      NcclManager::instance()->AddBroadcastSend(
          std::move(participant),
          {"bad_coll_type",
           /*num_local_devices=*/num_ranks,
           /*num_global_devices=*/num_ranks,
           /*communicator_key=*/"", /*source_rank=*/-1});
    }
  }

  this->VerifyError(test_case.get());
}

// Checks that we return an error status if a different communicator_key is
// passed to the same collective.
TYPED_TEST(NcclManagerTest, ConsistentCommunicatorKey) {
  const int num_ranks = 2;

  std::unique_ptr<typename TestFixture::TestCase> test_case(
      this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, ncclSum,
                                  TensorShape({2, 3}), 0.0f));
  for (int rank = 0; rank < num_ranks; ++rank) {
    auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
    auto* info = device->tensorflow_gpu_device_info();
    auto* stream = device->tensorflow_gpu_device_info()->stream;
    auto participant = absl::make_unique<NcclManager::Participant>(
        device->executor(), stream, info, &test_case->ins[rank],
        &test_case->outs[rank], /*global_rank=*/-1,
        this->CreateDoneCallback(test_case.get()));
    NcclManager::instance()->AddToAllReduce(
        std::move(participant),
        {"bad_coll_type",
         /*num_local_devices=*/num_ranks,
         /*num_global_devices=*/num_ranks,
         rank == 0 ? "" : NcclManager::instance()->GenerateCommunicatorKey(),
         /*source_rank=*/-1},
        ncclSum);
  }

  this->VerifyError(test_case.get());
}

// Checks that we return an error status if the number of devices is
// inconsistent across multiple participants of a collective.
TYPED_TEST(NcclManagerTest, ConsistentNumberOfDevices) {
  const int num_ranks = 2;

  std::unique_ptr<typename TestFixture::TestCase> test_case(
      this->MakeReductionTestCase(/*num_nodes=*/1, num_ranks, ncclSum,
                                  TensorShape({2, 3}), 0.0f));
  for (int rank = 0; rank < num_ranks; ++rank) {
    auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
    auto* info = device->tensorflow_gpu_device_info();
    auto* stream = device->tensorflow_gpu_device_info()->stream;
    int num_devices = rank == 0 ? num_ranks : num_ranks + 1;
    auto participant = absl::make_unique<NcclManager::Participant>(
        device->executor(), stream, info, &test_case->ins[rank],
        &test_case->outs[rank], /*global_rank=*/-1,
        this->CreateDoneCallback(test_case.get()));
    NcclManager::instance()->AddToAllReduce(std::move(participant),
                                            {"bad_coll_type",
                                             /*num_local_devices=*/num_devices,
                                             /*num_global_devices=*/num_devices,
                                             /*communicator_key=*/"",
                                             /*source_rank=*/-1},
                                            ncclSum);
  }

  this->VerifyError(test_case.get());
}

// Checks that we return an error status if a broadcast does not have a source.
TYPED_TEST(NcclManagerTest, BroadcastNoSource) {
  const int num_ranks = 2;

  std::unique_ptr<typename TestFixture::TestCase> test_case(
      this->MakeBroadcastTestCase(/*num_nodes=*/1, num_ranks,
                                  TensorShape({2, 3}), /*src_node=*/-1,
                                  /*src_rank=*/-1, false));
  for (int rank = 0; rank < num_ranks; ++rank) {
    auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
    auto* info = device->tensorflow_gpu_device_info();
    auto* stream = device->tensorflow_gpu_device_info()->stream;
    auto participant = absl::make_unique<NcclManager::Participant>(
        device->executor(), stream, info, nullptr, &test_case->outs[rank], rank,
        this->CreateDoneCallback(test_case.get()));
    NcclManager::instance()->AddBroadcastRecv(std::move(participant),
                                              {"bcast_no_send",
                                               /*num_local_devices=*/num_ranks,
                                               /*num_global_devices=*/num_ranks,
                                               /*communicator_key=*/"",
                                               /*source_rank=*/-1});
  }

  this->VerifyError(test_case.get());
}

// Checks that we return an error status if a broadcast has multiple sends.
TYPED_TEST(NcclManagerTest, BroadcastMultipleSends) {
  const int num_ranks = 2;

  std::unique_ptr<typename TestFixture::TestCase> test_case(
      this->MakeBroadcastTestCase(/*num_nodes=*/1, num_ranks,
                                  TensorShape({2, 3}), /*src_node=*/-1,
                                  /*src_rank=*/-1, false));
  for (int rank = 0; rank < num_ranks; ++rank) {
    auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
    auto* info = device->tensorflow_gpu_device_info();
    auto* stream = device->tensorflow_gpu_device_info()->stream;
    auto participant = absl::make_unique<NcclManager::Participant>(
        device->executor(), stream, info, &test_case->outs[rank],
        &test_case->outs[rank], rank,
        this->CreateDoneCallback(test_case.get()));
    NcclManager::instance()->AddBroadcastSend(std::move(participant),
                                              {"bcast_multiple_send",
                                               /*num_local_devices=*/num_ranks,
                                               /*num_global_devices=*/num_ranks,
                                               /*communicator_key=*/"",
                                               /*source_rank=*/-1});
  }

  this->VerifyError(test_case.get());
}

// Checks that we return an error status if a broadcast has inconsistent
// source ranks.
TYPED_TEST(NcclManagerTest, BroadcastInconsistentSource) {
  const int num_ranks = 2;

  std::unique_ptr<typename TestFixture::TestCase> test_case(
      this->MakeBroadcastTestCase(/*num_nodes=*/1, num_ranks,
                                  TensorShape({2, 3}), /*src_node=*/-1,
                                  /*src_rank=*/-1, false));
  for (int rank = 0; rank < num_ranks; ++rank) {
    auto* device = this->GetDevice(num_ranks, /*node=*/0, rank);
    auto* info = device->tensorflow_gpu_device_info();
    auto* stream = device->tensorflow_gpu_device_info()->stream;
    auto participant = absl::make_unique<NcclManager::Participant>(
        device->executor(), stream, info, &test_case->outs[rank],
        &test_case->outs[rank], rank,
        this->CreateDoneCallback(test_case.get()));
    NcclManager::instance()->AddBroadcastRecv(std::move(participant),
                                              {"bcast_inconsistent_source",
                                               /*num_local_devices=*/num_ranks,
                                               /*num_global_devices=*/num_ranks,
                                               /*communicator_key=*/"",
                                               /*source_rank=*/rank});
  }

  this->VerifyError(test_case.get());
}

#if !TENSORFLOW_USE_ROCM
// The ROCm platform currently does not support simulating a multi-node
// environment on a single node with multiple GPUs, so tests that rely on such
// simulation are skipped on ROCm.

TYPED_TEST(NcclManagerTest, AbortThenReset) {
  using NodeState = typename TestFixture::NodeState;
  using TestCase = typename TestFixture::TestCase;
  const int num_nodes = 2;
  std::vector<NodeState> nodes(num_nodes);
  // First do a normal all-reduce to simulate the case when there are
  // multiple communicators.
  this->RunMultiNodeAllReduceTest(nodes, /* num_ranks_per_node */ 1);

  const string collective_key = "allreduce";
  ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(0);
  auto node_fn = [&](TestCase* test_case, int node,
                     const string& communicator_key) {
    auto* device = this->GetDevice(/* num_ranks_per_node */ 1, node,
                                   /* local_rank */ 0);
    auto* info = device->tensorflow_gpu_device_info();
    auto* stream = device->tensorflow_gpu_device_info()->stream;
    auto participant = absl::make_unique<NcclManager::Participant>(
        device->executor(), stream, info, &test_case->ins[node],
        &test_case->outs[node], /* global_rank */ node,
        this->CreateDoneCallback(test_case));
    nodes[node].nccl_manager.AddToAllReduce(
        std::move(participant),
        {collective_key, /* num_local_devices */ 1,
         /* num_global_devices */ num_nodes, communicator_key,
         /*source_rank=*/-1},
        reduction_op);
    nodes[node].nccl_manager.SignalMultiNodeReady(collective_key);
  };

  // Use a new communicator_key, which uses a new set of ncclComms underneath.
  string communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
  // Do a normal all-reduce with this communicator key to initialize ncclComm.
  // This is because ncclCommInitRank waits for all ranks and is blocking.
  {
    std::unique_ptr<typename TestFixture::TestCase> test_case(
        this->MakeReductionTestCase(
            /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
            TensorShape({2, 3}), 0.0f));
    for (int i = 0; i < num_nodes; ++i) {
      this->work_queue_->Schedule(
          [&node_fn, &test_case, i, communicator_key]() {
            node_fn(test_case.get(), i, communicator_key);
          });
    }
    this->VerifyResults(test_case.get());
  }

  // A hanging all-reduce.
  ASSERT_GT(num_nodes, 1);
  std::unique_ptr<typename TestFixture::TestCase> test_case(
      this->MakeReductionTestCase(
          /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
          TensorShape({2, 3}), 0.0f));
  node_fn(test_case.get(), 0, communicator_key);
  Env::Default()->SleepForMicroseconds(1000000);
  for (auto& node : nodes) {
    node.nccl_manager.StartAbort(errors::Unavailable("peer down"));
  }
  {
    mutex_lock l(test_case->mu);
    while (test_case->num_completed != 1) {
      test_case->done_cv.wait(l);
    }
  }

  // Reset the aborted NcclManagers and then run another all-reduce with the
  // reset NcclManagers.
  for (auto& node : nodes) {
    node.nccl_manager.Reset();
  }
  // Regenerate the communicator_key, because this is needed to create new
  // communicators.
  communicator_key = nodes[0].nccl_manager.GenerateCommunicatorKey();
  {
    std::unique_ptr<typename TestFixture::TestCase> test_case(
        this->MakeReductionTestCase(
            /* num_nodes */ num_nodes, /* num_ranks_per_node */ 1, reduction_op,
            TensorShape({2, 3}), 0.0f));
    for (int i = 0; i < num_nodes; ++i) {
      this->work_queue_->Schedule(
          [&node_fn, &test_case, i, communicator_key]() {
            node_fn(test_case.get(), i, communicator_key);
          });
    }
    this->VerifyResults(test_case.get());
  }
}

#endif

}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM