1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/nccl/nccl_manager.h"
16 
17 #include <utility>
18 
19 #ifdef GOOGLE_CUDA
20 
21 #include "tensorflow/core/lib/core/threadpool.h"
22 #include "tensorflow/core/platform/cuda.h"
23 #include "tensorflow/core/platform/env.h"
24 
25 namespace tensorflow {
26 
// Evaluates an expression yielding an ncclResult_t exactly once and, if the
// result is not ncclSuccess, returns errors::Internal carrying the NCCL error
// string from the enclosing (Status-returning) function.
#define NCCL_RETURN_IF_ERROR(...)                               \
  do {                                                          \
    ncclResult_t nccl_status = (__VA_ARGS__);                   \
    if (nccl_status != ncclSuccess) {                           \
      return errors::Internal(ncclGetErrorString(nccl_status)); \
    }                                                           \
  } while (0)

// Same as NCCL_RETURN_IF_ERROR, but for expressions yielding a cudaError_t;
// the returned errors::Internal status carries the CUDA error string.
#define CUDA_RETURN_IF_ERROR(...)                               \
  do {                                                          \
    cudaError_t cuda_status = (__VA_ARGS__);                    \
    if (cuda_status != cudaSuccess) {                           \
      return errors::Internal(cudaGetErrorString(cuda_status)); \
    }                                                           \
  } while (0)

// Used by LoopKernelLaunches to activate the NcclStream's executor context.
using se::cuda::ScopedActivateExecutorContext;
44 
45 // Contains data for a single stream used for nccl communication; this includes
46 // a background thread that calls NcclManager::LoopKernelLaunches.
47 struct NcclManager::NcclStream {
48  public:
NcclStreamtensorflow::NcclManager::NcclStream49   NcclStream() {}
~NcclStreamtensorflow::NcclManager::NcclStream50   ~NcclStream() {
51     mutex_lock l(mu);
52     shutdown_requested = true;
53     cv.notify_all();
54   }
55 
56   se::StreamExecutor* executor = nullptr;
57 
58   // The stream on which to run the nccl collective.
59   // This is a different stream than the tensorflow compute stream.
60   std::unique_ptr<se::Stream> stream;
61 
62   // See NcclManager::LoopKernelLaunches for information on these.
63   std::unique_ptr<Thread> thread;
64   mutex mu;
65   condition_variable cv;
66   // Has collective,participant_idx pairs.
67   std::deque<std::pair<Collective*, int>> pending_launches_ GUARDED_BY(mu);
68   bool shutdown_requested GUARDED_BY(mu) = false;
69 };
70 
71 struct NcclManager::CommunicatorMember {
72  public:
CommunicatorMembertensorflow::NcclManager::CommunicatorMember73   CommunicatorMember() {}
~CommunicatorMembertensorflow::NcclManager::CommunicatorMember74   ~CommunicatorMember() {
75     if (nccl_comm != nullptr) ncclCommDestroy(nccl_comm);
76   }
77   ncclComm_t nccl_comm;
78 
79   // Owned by NcclManager::device_to_comm_streams_.
80   NcclStream* nccl_stream = nullptr;
81 };
82 
83 struct NcclManager::Communicator {
84  public:
Communicatortensorflow::NcclManager::Communicator85   explicit Communicator(std::vector<CommunicatorMember> members,
86                         const string& key)
87       : num_devices(members.size()), members(std::move(members)), key(key) {}
88 
89   const int num_devices;
90   const std::vector<CommunicatorMember> members;
91   const string key;
92 };
93 
94 namespace {
95 
ToNcclType(DataType t)96 ncclDataType_t ToNcclType(DataType t) {
97   switch (t) {
98     case DT_HALF:
99       return ncclHalf;
100     case DT_FLOAT:
101       return ncclFloat;
102     case DT_DOUBLE:
103       return ncclDouble;
104     case DT_INT32:
105       return ncclInt;
106     case DT_INT64:
107       return ncclInt64;
108     default:
109       return ncclFloat;
110   }
111 }
112 
StringToNcclUniqueId(const string & str_id,ncclUniqueId * nccl_id)113 void StringToNcclUniqueId(const string& str_id, ncclUniqueId* nccl_id) {
114   if (str_id.size() == NCCL_UNIQUE_ID_BYTES) {
115     memcpy(nccl_id->internal, str_id.data(), NCCL_UNIQUE_ID_BYTES);
116   }
117 }
118 
119 }  // namespace
120 
121 // A `Collective` encapsulates state for a collective instance at one node.
122 // Typically, an instance in TensorFlow context would be defined by a collective
123 // group and the (step, frame iteration) for that execution.
124 //
125 // For each collective instance there will be one `Collective` object per node.
126 // For example,  a NCCL collective that runs on a single node with 4 GPUs would
127 // have a single `Collective` per step.  However, a collective that executes on
128 // 3 nodes with 4 GPUs each would have a `Collective` per node, each of which is
129 // tracking the 4 GPUs local to that node.
130 struct NcclManager::Collective {
Collectivetensorflow::NcclManager::Collective131   Collective(DataType data_type_in, CollectiveType type_in,
132              ncclRedOp_t reduction_op_in, int num_local_devices_in,
133              int num_global_devices_in, const string& communicator_key_in)
134       : data_type(data_type_in),
135         type(type_in),
136         reduction_op(reduction_op_in),
137         num_local_devices(num_local_devices_in),
138         num_global_devices(num_global_devices_in),
139         single_node(num_local_devices_in == num_global_devices_in),
140         communicator_key(communicator_key_in),
141         remaining_participants(num_local_devices_in) {
142     participants.reserve(num_local_devices_in);
143   }
144 
145   const DataType data_type;
146   const CollectiveType type;
147   const ncclRedOp_t reduction_op;  // applies when <type> is a reduction.
148   const int num_local_devices;     // devices local to this node
149   const int num_global_devices;    // devices across all nodes
150   const bool single_node;          // true if all devices are at one node
151   const string communicator_key;
152 
153   Communicator* communicator = nullptr;
154 
155   // All collective participants.
156   //
157   // Adding values in this vector is guarded by the mutex of the containing
158   // NcclManager.
159   std::vector<std::unique_ptr<Participant>> participants;
160 
161   // For collective types that have a root (e.g. the root of broadcast is the
162   // sender), this is the rank of the root.
163   int root_rank = -1;
164 
165   // How many participants have been registered so far. The Collective is
166   // eligible for running with <available_participants> == num_local_devices.
167   //
168   // If this is a multi-node collective, we additionally have to synchronize
169   // across nodes.  The caller would need to signal multi node readiness by
170   // calling NcclManager::SignalMultiNodeReady, which sets `multi_node_ready` to
171   // true.
172   //
173   // Guarded by the mutex of the containing Communicator.
174   int available_participants = 0;
175   bool multi_node_ready = false;
176 
177   mutable std::atomic_int_fast32_t remaining_participants;
178 
179   Status status;
180 };
181 
NcclManager()182 NcclManager::NcclManager() {}
~NcclManager()183 NcclManager::~NcclManager() {}
instance()184 NcclManager* NcclManager::instance() {
185   static NcclManager* instance = new NcclManager();
186   return instance;
187 }
188 
GenerateCommunicatorKey()189 string NcclManager::GenerateCommunicatorKey() {
190   ncclUniqueId nccl_id;
191   ncclGetUniqueId(&nccl_id);
192   return string(nccl_id.internal, NCCL_UNIQUE_ID_BYTES);
193 }
194 
GetCommunicator(NcclManager::Collective * collective,NcclManager::Communicator ** communicator)195 Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
196                                     NcclManager::Communicator** communicator) {
197   // Sort by executor to make ordering of executors deterministic.
198   std::sort(collective->participants.begin(), collective->participants.end(),
199             [](const std::unique_ptr<Participant>& a,
200                const std::unique_ptr<Participant>& b) {
201               return a->executor < b->executor;
202             });
203 
204   mutex_lock l(mu_);
205 
206   if (collective->single_node) {
207     // For single-node collectives, we identify a communicator uniquely by the
208     // set of devices participating in the collective.  For example, if a
209     // collective is for GPUs 0, 1, and 2 then this will scan to find the
210     // communicator for GPUs 0, 1, and 2.
211     //
212     // Note that each executor identifies a context on one device, so this is
213     // the same as getting the communicator connecting the devices in the
214     // collective. A device can be in different communicators as well - for
215     // example, a communicator for GPUs 0 and 1 is separate from one for GPUs 0,
216     // 1, and 2.
217     //
218     // Since it's expected that a small number of distinct communicators will
219     // be needed, communicators_ is not garbage collected currently.
220     //
221     // Launching of kernels must be serialized so that, given collectives A and
222     // B, and an order of them (e.g., A before B), then for each comm_stream
223     // involved, the kernel for A is launched before the kernel for B. This is
224     // guaranteed currently be a global mutex controlling additions of the
225     // kernels to per-stream launch queues.  The launch queues are processed by
226     // LoopKernelLaunches.
227     for (auto& comm : communicators_) {
228       if (comm->num_devices == collective->num_global_devices) {
229         int i;
230         for (i = 0; i < collective->num_local_devices; ++i) {
231           if (comm->members[i].nccl_stream->executor !=
232               collective->participants[i]->executor) {
233             break;
234           }
235         }
236         if (i == collective->num_local_devices) {
237           *communicator = comm.get();
238           return Status::OK();
239         }
240       }
241     }
242   } else {
243 #if NCCL_MAJOR < 2
244     return errors::Internal(
245         "Cannot use multi-node NCCL collectives with NCCL 1.x");
246 #endif
247     if (collective->communicator_key.size() != NCCL_UNIQUE_ID_BYTES) {
248       return errors::Internal("Expected communicator_key of size ",
249                               NCCL_UNIQUE_ID_BYTES, " but found size ",
250                               collective->communicator_key.size());
251     }
252     // This is an instance of multi-node collective.  We have previously
253     // created a NCCL unique id and shared with all workers.  Now we find the
254     // `Communicator` corresponding to this id.
255     for (auto& comm : communicators_) {
256       if (comm->key == collective->communicator_key) {
257         *communicator = comm.get();
258         return Status::OK();
259       }
260     }
261   }
262 
263   auto* env = Env::Default();
264   std::set<NcclStream*> used_streams;
265 
266   // Create and initialize a new communicator.
267   // Note that this is done under the lock; performance is not expected to
268   // matter as this happens a very small number of times.
269   std::vector<CommunicatorMember> members(collective->num_local_devices);
270   std::vector<int> devices(collective->num_local_devices);
271   for (int i = 0; i < collective->num_local_devices; ++i) {
272     auto* executor = collective->participants[i]->executor;
273 
274     // Find a communication stream to use for the device.
275     auto& streams = device_to_comm_streams_[executor];
276     NcclStream* nccl_stream = nullptr;
277     for (const auto& s : streams) {
278       if (used_streams.insert(s.get()).second) {
279         nccl_stream = s.get();
280         break;
281       }
282     }
283     if (nccl_stream == nullptr) {
284       nccl_stream = new NcclStream();
285       nccl_stream->executor = executor;
286       nccl_stream->stream.reset(new se::Stream(executor));
287       nccl_stream->stream->Init();
288 
289       streams.emplace_back(nccl_stream);
290       used_streams.insert(nccl_stream);
291 
292       nccl_stream->thread.reset(env->StartThread(
293           ThreadOptions(), "nccl_kernel_launch",
294           [this, nccl_stream] { LoopKernelLaunches(nccl_stream); }));
295     }
296 
297     members[i].nccl_stream = nccl_stream;
298     devices[i] = collective->participants[i]->gpu_device_id;
299   }
300 
301   std::vector<ncclComm_t> nccl_comms(collective->num_local_devices);
302 #if NCCL_MAJOR >= 2
303   // For NCCL 2, we always initialize using ncclCommInitRank guarded by NCCL
304   // group primitives.
305   ncclUniqueId nccl_id;
306   if (collective->single_node) {
307     NCCL_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
308   } else {
309     StringToNcclUniqueId(collective->communicator_key, &nccl_id);
310   }
311   int saved_device = 0;
312   CUDA_RETURN_IF_ERROR(cudaGetDevice(&saved_device));
313   NCCL_RETURN_IF_ERROR(ncclGroupStart());
314   for (int i = 0; i < collective->num_local_devices; ++i) {
315     // Set rank to `participant->global_rank` if provided, else `i`.
316     const int rank = collective->participants[i]->global_rank >= 0
317                          ? collective->participants[i]->global_rank
318                          : i;
319     CUDA_RETURN_IF_ERROR(cudaSetDevice(devices[i]));
320     NCCL_RETURN_IF_ERROR(ncclCommInitRank(
321         nccl_comms.data() + i, collective->num_global_devices, nccl_id, rank));
322   }
323   NCCL_RETURN_IF_ERROR(ncclGroupEnd());
324   CUDA_RETURN_IF_ERROR(cudaSetDevice(saved_device));
325 #else
326   // Since NCCL 1 is single node only, we use ncclCommInitAll.  We could have
327   // used ncclCommInitRank with NCCL 1 as well, but then we would have to
328   // issue each init call from a different thread
329   // (https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/nccl1.html).
330   NCCL_RETURN_IF_ERROR(ncclCommInitAll(
331       nccl_comms.data(), collective->num_local_devices, devices.data()));
332 #endif
333 
334   for (int i = 0; i < collective->num_local_devices; ++i) {
335     members[i].nccl_comm = nccl_comms[i];
336   }
337   communicators_.emplace_back(
338       new Communicator(std::move(members), collective->communicator_key));
339   *communicator = communicators_.back().get();
340   return Status::OK();
341 }
342 
AddToAllReduce(std::unique_ptr<Participant> participant,const Context & context,ncclRedOp_t reduction_op)343 void NcclManager::AddToAllReduce(std::unique_ptr<Participant> participant,
344                                  const Context& context,
345                                  ncclRedOp_t reduction_op) {
346   AddParticipant(std::move(participant), context, kAllReduce, reduction_op);
347 }
348 
AddToAllGather(std::unique_ptr<Participant> participant,const Context & context)349 void NcclManager::AddToAllGather(std::unique_ptr<Participant> participant,
350                                  const Context& context) {
351   AddParticipant(std::move(participant), context, kAllGather,
352                  ncclSum /* unused */);
353 }
354 
AddBroadcastSend(std::unique_ptr<Participant> participant,const Context & context)355 void NcclManager::AddBroadcastSend(std::unique_ptr<Participant> participant,
356                                    const Context& context) {
357   participant->root = true;
358   AddParticipant(std::move(participant), context, kBroadcast,
359                  ncclSum /* unused */);
360 }
361 
AddBroadcastRecv(std::unique_ptr<Participant> participant,const Context & context)362 void NcclManager::AddBroadcastRecv(std::unique_ptr<Participant> participant,
363                                    const Context& context) {
364   AddParticipant(std::move(participant), context, kBroadcast,
365                  ncclSum /* unused */);
366 }
367 
AddReduceSend(std::unique_ptr<Participant> participant,const Context & context,ncclRedOp_t reduction_op)368 void NcclManager::AddReduceSend(std::unique_ptr<Participant> participant,
369                                 const Context& context,
370                                 ncclRedOp_t reduction_op) {
371   AddParticipant(std::move(participant), context, kReduce, reduction_op);
372 }
373 
AddReduceRecv(std::unique_ptr<Participant> participant,const Context & context,ncclRedOp_t reduction_op)374 void NcclManager::AddReduceRecv(std::unique_ptr<Participant> participant,
375                                 const Context& context,
376                                 ncclRedOp_t reduction_op) {
377   AddParticipant(std::move(participant), context, kReduce, reduction_op);
378 }
379 
SignalMultiNodeReady(const string & collective_key)380 void NcclManager::SignalMultiNodeReady(const string& collective_key) {
381   Collective* to_run = nullptr;
382   {
383     mutex_lock l(mu_);
384     auto collective_it = collectives_.find(collective_key);
385     if (collective_it != collectives_.end()) {
386       Collective* collective = collective_it->second.get();
387       collective->multi_node_ready = true;
388       to_run = CheckReady(collective_key, collective);
389     }
390   }
391 
392   if (to_run != nullptr) RunCollective(to_run);
393 }
394 
AddParticipant(std::unique_ptr<Participant> participant,const Context & context,CollectiveType collective_type,ncclRedOp_t reduction_op)395 void NcclManager::AddParticipant(std::unique_ptr<Participant> participant,
396                                  const Context& context,
397                                  CollectiveType collective_type,
398                                  ncclRedOp_t reduction_op) {
399   Collective* to_run = nullptr;
400   const DataType data_type = participant->input->dtype();
401   {
402     mutex_lock l(mu_);
403     auto collective_it = collectives_.find(context.collective_key);
404     Collective* collective = nullptr;
405     if (collective_it == collectives_.end()) {
406       auto collective_unique_ptr = absl::make_unique<Collective>(
407           data_type, collective_type, reduction_op, context.num_local_devices,
408           context.num_global_devices, context.communicator_key);
409       collective = collective_unique_ptr.get();
410       collectives_.emplace(context.collective_key,
411                            std::move(collective_unique_ptr));
412     } else {
413       collective = collective_it->second.get();
414     }
415 
416     // Check `collective` is correct and consistent.
417     if (collective->status.ok() && collective->single_node &&
418         !collective->communicator_key.empty()) {
419       collective->status =
420           errors::Internal("Collective ", reduction_op,
421                            " is single node but has communicator_key of size ",
422                            collective->communicator_key.size());
423     }
424     if (collective->status.ok() && collective->communicator_key.size() !=
425                                        context.communicator_key.size()) {
426       collective->status =
427           errors::Internal("Collective ", reduction_op,
428                            " mismatch in member communicator_key with size ",
429                            collective->communicator_key.size(),
430                            " and arg communicator_key with size ",
431                            context.communicator_key.size());
432     }
433     if (collective->status.ok() && collective->type != collective_type) {
434       collective->status = errors::Internal(
435           "Collective ", reduction_op, " previously initialized with type ",
436           collective->type, " but now got type ", collective_type);
437     }
438     if (collective->status.ok() &&
439         collective->num_global_devices != context.num_global_devices) {
440       collective->status =
441           errors::Internal("Collective ", reduction_op,
442                            " previously initialized with num_global_devices ",
443                            collective->num_global_devices, " but now got ",
444                            context.num_global_devices);
445     }
446     if (collective->status.ok() &&
447         collective->num_local_devices != context.num_local_devices) {
448       collective->status =
449           errors::Internal("Collective ", reduction_op,
450                            "previously initialized with num_local_devices ",
451                            collective->num_local_devices, " but now got ",
452                            context.num_local_devices);
453     }
454     if (collective->status.ok() &&
455         collective->participants.size() >= collective->num_local_devices) {
456       collective->status = errors::Internal(
457           "Collective ", reduction_op, " expected ",
458           collective->num_local_devices, " participants but now has ",
459           collective->participants.size(),
460           " with one more participant being added");
461     }
462 
463     collective->participants.emplace_back(std::move(participant));
464     ++collective->available_participants;
465 
466     to_run = CheckReady(context.collective_key, collective);
467   }
468 
469   if (to_run != nullptr) RunCollective(to_run);
470 }
471 
CheckReady(const string & collective_key,Collective * collective)472 NcclManager::Collective* NcclManager::CheckReady(const string& collective_key,
473                                                  Collective* collective) {
474   Collective* to_run = nullptr;
475   if (collective->available_participants == collective->num_local_devices) {
476     if (collective->num_global_devices == collective->num_local_devices ||
477         collective->multi_node_ready) {
478       // Ownership transferred to callee.
479       to_run = collective;
480       auto collectives_it = collectives_.find(collective_key);
481       collectives_it->second.release();
482       collectives_.erase(collectives_it);
483     }
484   }
485   return to_run;
486 }
487 
RunCollective(Collective * collective)488 void NcclManager::RunCollective(Collective* collective) {
489   static mutex collective_mu(LINKER_INITIALIZED);
490 
491   Status s = collective->status;
492   if (s.ok()) {
493     s = GetCommunicator(collective, &collective->communicator);
494   }
495   if (!s.ok()) {
496     for (int i = 0; i < collective->num_local_devices; ++i) {
497       collective->participants[i]->done_callback(s);
498     }
499     delete collective;
500     return;
501   }
502 
503   for (int i = 0; i < collective->num_local_devices; ++i) {
504     Participant* p = collective->participants[i].get();
505     NcclStream* nccl_stream = collective->communicator->members[i].nccl_stream;
506     CHECK(nccl_stream != nullptr);
507     const int rank = p->global_rank >= 0 ? p->global_rank : i;
508 
509     if (p->input != nullptr) {
510       // Wait to ensure that the kernel that produces the data in the input
511       // tensor has finished running before the nccl kernel runs on the
512       // communication stream.
513       nccl_stream->stream->ThenWaitFor(p->tensor_stream);
514     }
515     if (p->root) {
516       CHECK_EQ(collective->root_rank, -1);
517       collective->root_rank = rank;
518     }
519   }
520 
521   if (collective->type == kBroadcast) {
522     CHECK_NE(collective->root_rank, -1);
523   }
524 
525   {
526     // Allow only one collective at a time to queue kernels for launching. This
527     // is to prevent collectives from deadlocking each other.
528     // Note that it would be possible to run multiple collectives at once, if
529     // they have non-intersecting sets of devices.
530     mutex_lock l(collective_mu);
531     for (int i = 0; i < collective->num_local_devices; ++i) {
532       NcclStream* nccl_stream =
533           collective->communicator->members[i].nccl_stream;
534       mutex_lock l(nccl_stream->mu);
535       nccl_stream->pending_launches_.push_front(std::make_pair(collective, i));
536       nccl_stream->cv.notify_all();
537     }
538   }
539 }
540 
// Body of the background thread owned by `nccl_stream` (started in
// GetCommunicator).
//
// Repeatedly takes the next queued (collective, participant index) launch and
// enqueues the matching NCCL kernel on the NcclStream's communication stream.
// Launches are taken in FIFO order: RunCollective pushes to the front and
// this loop pops from the back. Returns once the queue is empty and shutdown
// has been requested.
//
// Each participant's done_callback is scheduled through its event_mgr on the
// communication stream. The callback that runs last deletes the Collective.
void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
  se::Stream* comm_stream = nccl_stream->stream.get();
  // Activate the stream's executor context for every launch on this thread.
  ScopedActivateExecutorContext scoped_context(nccl_stream->executor);
  // The raw CUDA stream behind comm_stream, as required by the NCCL C API.
  const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>(
      comm_stream->implementation()->GpuStreamMemberHack());

  while (true) {
    // Find collective to run.
    std::pair<Collective*, int> next_launch;
    {
      mutex_lock l(nccl_stream->mu);
      while (nccl_stream->pending_launches_.empty()) {
        if (nccl_stream->shutdown_requested) {
          // No work and shutdown requested, exit.
          return;
        }
        nccl_stream->cv.wait(l);
      }
      next_launch = nccl_stream->pending_launches_.back();
      nccl_stream->pending_launches_.pop_back();
    }

    // Launch the nccl kernel.
    Collective* collective = next_launch.first;
    ncclDataType_t data_type = ToNcclType(collective->data_type);
    int p_idx = next_launch.second;
    Participant* p = collective->participants[p_idx].get();
    auto nccl_comm = collective->communicator->members[p_idx].nccl_comm;
    ncclResult_t nccl_result = ncclSuccess;
    switch (collective->type) {
      case kAllReduce: {
        const void* sendbuff = p->input->tensor_data().data();
        void* recvbuff = const_cast<char*>(p->output->tensor_data().data());

        VLOG(2) << "call NcclAllReduce participant " << p_idx << " sendbuff "
                << sendbuff << " recvbuff " << recvbuff << " nccl_comm "
                << nccl_comm << " comm_stream " << comm_stream
                << " cuda_stream " << cu_stream;
        nccl_result = ncclAllReduce(sendbuff, recvbuff, p->input->NumElements(),
                                    data_type, collective->reduction_op,
                                    nccl_comm, *cu_stream);
        break;
      }
      case kBroadcast: {
        // Participants without an input (receivers) use their output tensor
        // as the broadcast buffer.
        const Tensor* buf_t = p->input ? p->input : p->output;
        void* buf = const_cast<char*>(buf_t->tensor_data().data());
        nccl_result = ncclBcast(buf, buf_t->NumElements(), data_type,
                                collective->root_rank, nccl_comm, *cu_stream);
        break;
      }
      case kReduce: {
        // Participants without an output tensor pass a null recvbuff.
        const void* sendbuff = p->input->tensor_data().data();
        void* recvbuff =
            p->output ? const_cast<char*>(p->output->tensor_data().data())
                      : nullptr;
        nccl_result = ncclReduce(sendbuff, recvbuff, p->input->NumElements(),
                                 data_type, collective->reduction_op,
                                 collective->root_rank, nccl_comm, *cu_stream);
        break;
      }
      case kAllGather: {
        const void* sendbuff = p->input->tensor_data().data();
        void* recvbuff = const_cast<char*>(p->output->tensor_data().data());

        VLOG(2) << "call NcclAllGather participant " << p_idx << " sendbuff "
                << sendbuff << " sendcount " << p->input->NumElements()
                << " recvbuff " << recvbuff << " recvcount "
                << p->output->NumElements() << " nccl_comm " << nccl_comm
                << " comm_stream " << comm_stream << " cuda_stream "
                << cu_stream;
        nccl_result = ncclAllGather(sendbuff, recvbuff, p->input->NumElements(),
                                    data_type, nccl_comm, *cu_stream);
        break;
      }
    }

    // Run the done_callback when the nccl kernel finishes running.
    auto done_callback = [collective, p_idx, nccl_result]() {
      if (nccl_result == ncclSuccess) {
        collective->participants[p_idx]->done_callback(Status::OK());
      } else {
        // Propagate the error, but note that if other members of the collective
        // did launch their kernels, then they are hanging.
        collective->participants[p_idx]->done_callback(errors::Unknown(
            "Error invoking NCCL: ", ncclGetErrorString(nccl_result)));
      }

      // TODO(cwhipkey): use RefCounted after figuring out how to use in a
      // custom op library.
      // See tensorflow/core/lib/core/refcount.h for details on this locking.
      // The participant that drops remaining_participants to zero owns the
      // final reference and deletes the collective.
      if (collective->remaining_participants.load(std::memory_order_acquire) ==
              1 ||
          collective->remaining_participants.fetch_sub(1) == 1) {
        delete collective;
      }
    };
    p->event_mgr->ThenExecute(comm_stream, done_callback);
  }
}
640 
641 }  // namespace tensorflow
642 
643 #endif  // GOOGLE_CUDA
644