/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"

#include "tensorflow/compiler/xla/util.h"

#if GOOGLE_CUDA
#include "absl/synchronization/blocking_counter.h"
#include "third_party/nccl/nccl.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#endif

namespace xla {
namespace gpu {

/* static */ bool NcclAllReduceThunk::NcclIsEnabled() {
#if GOOGLE_CUDA
  return true;
#else
  return false;
#endif
}

#if GOOGLE_CUDA
namespace {

// GPU-replica-driving host threads (i.e. the threads that call
// GpuExecutable::Execute) build up this structure to describe their
// participating replica, and then call into
// GlobalRendezvousManager::SubmitParticipant.
struct ParticipantData {
  // Number of replicas participating in the AllReduce.
  int64 replica_count;

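  // Number of elements being all-reduced, the device ordinal of the
  // participating GPU, and the rendezvous generation this participant was
  // built for (checked against GlobalRendezvousManager::GetCurrentGeneration).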
  int64 element_count;
  int64 device_ordinal;
  int64 generation_counter;

  // TODO(b/125951860): We should vet that we're buffer allocating such that
  // source_buffer == destination_buffer if that avoids a NCCL copy (will depend
  // on how well the NCCL in-place implementation performs vs the out-of-place
  // implementation).
  se::DeviceMemoryBase source_data;
  se::DeviceMemoryBase destination_data;
  se::Stream* stream;

  NcclAllReduceThunk* originator;

  string ToString() const {
    return absl::StrFormat(
        "ParticipantData{replica_count=%d, element_count=%d, "
        "device_ordinal=%d, generation_counter=%d, stream=%p, originator=%p}",
        replica_count, element_count, device_ordinal, generation_counter,
        stream, originator);
  }
};

// Class that gets instantiated as a singleton in GetGlobalRendezvous() to
// coordinate participating threads in performing an AllReduce operation.
//
// This manager is responsible for establishing communication channels and
// ultimately enqueueing the NCCL library operation onto the participating
// streams.
class GlobalRendezvousManager {
 public:
  // The GpuExecutable-executing threads call this in order to a) establish the
  // all-reduce rendezvous and b) enqueue the AllReduce operation on the caller
  // thread's associated stream (given in "participant").
  //
  // Implementation note: since the rendezvous we're creating here is global, we
  // try to be paranoid about checking that the *correct* one is happening.  In
  // an ideal world we'd have some StreamExecutor se::Platform level construct
  // that we could use for cross-device networking primitives (e.g. via a
  // NetworkSupport interface) that could be shared between TensorFlow and XLA,
  // but this is a reasonable stopgap measure to get multi-GPU-replica up and
  // running properly for single-host, single-concurrent-XLA-module usage.
  Status SubmitParticipant(ParticipantData participant);

  // Returns the current generation number of AllReduce operations.
  // (Currently one AllReduce operation occurs per generation.)
  int64 GetCurrentGeneration() {
    tensorflow::mutex_lock lock(mutex_);
    return current_generation_;
  }

 private:
  // Called by the primary thread to set up the communication links.
  //
  // TODO(b/125951860): This performs lots of (presumably) unnecessary host-side
  // synchronization so that we can be paranoid about semantics in the earliest
  // implementation. In the limit we should only need to synchronize host
  // replica threads when the "number of replicas" or "participating device
  // ordinals" change, to set up a new NCCL "communication" context, at which
  // point we can enqueue onto device streams without host synchronization in
  // our code -- this will likely be helpful for "lots of little AllReduce"
  // cases.
  Status InitializeCommunicationChannels() EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // Called when all necessary participants are present; the work that every
  // executing thread performs lives in here.
  Status DoAllReduce(ParticipantData data, ncclComm_t comm);

  // Puts all state back into a "reset" state for the next generation of
  // AllReduce requests.
  void DeinitializeGeneration() EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
    for (ncclComm_t& comm : comms_) {
      ncclCommDestroy(comm);
    }
    comms_.clear();
    participants_.clear();
    current_generation_++;
    initialized_ = false;
    done_ = absl::nullopt;
  }

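  // Guards the rendezvous state below. all_participants_present_ wakes every
  // submitting thread once the final participant arrives; deinitialized_ wakes
  // the non-primary threads once the primary thread has torn down the current
  // generation.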
  tensorflow::mutex mutex_;
  tensorflow::condition_variable all_participants_present_;
  tensorflow::condition_variable deinitialized_;

  // Communication handles that correspond to the participants below.
  std::vector<ncclComm_t> comms_ GUARDED_BY(mutex_);

  Status initialize_status_ GUARDED_BY(mutex_);
  std::vector<ParticipantData> participants_ GUARDED_BY(mutex_);
  int64 current_generation_ GUARDED_BY(mutex_) = 0;
  bool initialized_ GUARDED_BY(mutex_) = false;

  // The participating threads wait for this to count down in order to know we
  // can begin the teardown process.
  absl::optional<tensorflow::BlockingCounter> done_;
};

Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
  auto all_participants_present = [this, &participant]()
                                      EXCLUSIVE_LOCKS_REQUIRED(mutex_) -> bool {
    return participants_.size() >= participant.replica_count;
  };

  // We remember the participant index at which we are inserted and use that
  // same index for referring to auxiliary metadata (e.g. the ncclComm_t handle
  // index) below.
  int64 index;

  {
    tensorflow::mutex_lock lock(mutex_);

    // Spot check for consistent replica counts among submitting threads.
    if (!participants_.empty() &&
        (participants_.back().replica_count != participant.replica_count ||
         participants_.back().originator != participant.originator)) {
      return InvalidArgument(
          "Running two XLA modules with AllReduces in parallel is not "
          "supported. It is possible this is due to a bug where we try to "
          "run two different AllReduces from the same module at once. "
          "(Attempted a rendezvous with a different replica count from other "
          "participants; existing: %s; submitted: %s)",
          participants_.back().ToString(), participant.ToString());
    }
    index = participants_.size();
    participants_.push_back(participant);

    if (all_participants_present()) {
      all_participants_present_.notify_all();
    }
  }

  // We pull into our thread a) the communication handle and b) whether we're
  // the "primary" thread for this rendezvous -- the "primary" thread has some
  // additional responsibilities for setup/teardown.
  ncclComm_t comm;
  bool primary;

  {
    tensorflow::mutex_lock lock(mutex_);
    while (!all_participants_present()) {
      // Once all the participants have arrived, all participating threads will
      // cross this barrier, though only (the first) one will be the "primary".
      all_participants_present_.wait(lock);
    }

    // Somebody will be the first -- that thread has some additional
    // responsibilities.
    primary = !initialized_;

    CHECK_EQ(participant.generation_counter, current_generation_);

    // The primary thread sets up the communication channels and marks the
    // rendezvous initialized so the other threads know the AllReduce is ready
    // to enqueue.
    if (primary) {
      VLOG(3) << "Primary initializing accounting data.";
      initialized_ = true;
      done_.emplace(participant.replica_count);
      initialize_status_ = InitializeCommunicationChannels();
      VLOG(3) << "Done initializing communication channels; status: "
              << initialize_status_;
      if (!initialize_status_.ok()) {
        DeinitializeGeneration();
      }
    }

    if (!initialize_status_.ok()) {
      // TODO(b/125951860): If this fails once, it will fail forever.
      return initialize_status_;
    }

    comm = comms_[index];

    // Drop the lock at the end of scope so other participants may enter.
  }

  VLOG(3) << "Performing all reduce from device ordinal: "
          << participant.device_ordinal;

  Status all_reduce_status = DoAllReduce(participant, comm);

  VLOG(3) << "Waiting for all participants to complete enqueue.";

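  // Record that this thread has finished enqueueing its portion of the
  // all-reduce; the primary thread waits for the full count before tearing
  // down this generation's state.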
  done_->DecrementCount();

  if (primary) {
    // Primary thread clears out the AllReduce state when everybody is done to
    // make it clean-slate for any subsequent AllReduce request (e.g. number of
    // replicas may change in the next request).
    //
    // Note surrounding TODOs for only reinitializing this when the replica
    // count / participants actually change -- lots of "playing it safe"
    // happening in this first cut.
    done_->Wait();
    VLOG(3) << "All participants completed enqueue.";
    VLOG(3) << "Primary thread clearing.";
    tensorflow::mutex_lock lock(mutex_);
    DeinitializeGeneration();
    VLOG(3) << "Generation is now: " << current_generation_;
    deinitialized_.notify_all();
  } else {
    VLOG(3) << "Waiting to deinitialize.";
    tensorflow::mutex_lock lock(mutex_);
    while (initialized_) {
      deinitialized_.wait(lock);
    }
  }

  VLOG(3) << "Returning status: " << all_reduce_status;
  return all_reduce_status;
}

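// Creates one NCCL communicator per participant (indexed the same way as
// participants_) by calling ncclCommInitAll over the participating device
// ordinals.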
Status GlobalRendezvousManager::InitializeCommunicationChannels() {
  std::vector<int> ordinals;
  for (ParticipantData& data : participants_) {
    ordinals.push_back(data.device_ordinal);
  }
  comms_.resize(ordinals.size());
  VLOG(3) << "Participants: " << participants_.size()
          << "; initializing comms.";
  ncclResult_t result = ncclCommInitAll(comms_.data(), comms_.size(),
                                        /*devlist=*/ordinals.data());
  if (result != ncclSuccess) {
    comms_.clear();
    return InternalError(
        "Failed to initialize NCCL communication channels for %d participants: "
        "%s",
        participants_.size(), ncclGetErrorString(result));
  }
  return Status::OK();
}

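// Enqueues the NCCL all-reduce onto the participant's stream. Note that the
// element type is currently fixed to float32 and the reduction op to a sum.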
Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant,
                                            ncclComm_t comm) {
  se::StreamExecutor* executor = participant.stream->parent();
  se::cuda::ScopedActivateExecutorContext scoped_context(executor);
  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
      participant.stream->implementation()->GpuStreamMemberHack());
  VLOG(3) << "Using stream pointer: " << cu_stream
          << " on device: " << participant.device_ordinal;
  void* send_buffer = participant.source_data.opaque();
  void* recv_buffer = participant.destination_data.opaque();
  ncclResult_t result = ncclAllReduce(send_buffer, recv_buffer,
                                      /*count=*/participant.element_count,
                                      /*datatype=*/ncclFloat,
                                      /*op=*/ncclSum,
                                      /*comm=*/comm,
                                      /*stream=*/*cu_stream);
  TF_RET_CHECK(ncclSuccess == result)
      << "Failed to perform all-reduce: " << ncclGetErrorString(result);

  VLOG(3) << "Done performing all reduce for ordinal: "
          << participant.device_ordinal;

  return Status::OK();
}

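// Returns the process-wide rendezvous manager shared by every
// NcclAllReduceThunk; constructed lazily on first use and never destroyed.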
static GlobalRendezvousManager* GetGlobalRendezvous() {
  static auto* manager = new GlobalRendezvousManager;
  return manager;
}

}  // namespace

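// CUDA path: describes this replica's buffers and stream as a ParticipantData
// and submits it to the global rendezvous, returning once the all-reduce has
// been enqueued on every participating stream.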
Status NcclAllReduceThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  auto* global_rendezvous = GetGlobalRendezvous();

  ParticipantData participant;
  participant.replica_count = replica_count_;
  participant.element_count = element_count_;
  participant.device_ordinal = stream->parent()->device_ordinal();
  participant.generation_counter = global_rendezvous->GetCurrentGeneration();
  participant.source_data = buffer_allocations.GetDeviceAddress(source_buffer_);
  participant.destination_data =
      buffer_allocations.GetDeviceAddress(destination_buffer_);
  participant.stream = stream;
  participant.originator = this;

  return GetGlobalRendezvous()->SubmitParticipant(std::move(participant));
}
#else

Status NcclAllReduceThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
}

#endif  // GOOGLE_CUDA

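// The constructor only records the parameters; the buffer slices are resolved
// to device addresses at execution time via BufferAllocations::GetDeviceAddress.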
NcclAllReduceThunk::NcclAllReduceThunk(
    int64 replica_count, int64 element_count,
    const BufferAllocation::Slice& source_buffer,
    const BufferAllocation::Slice& destination_buffer,
    const HloInstruction* all_reduce)
    : Thunk(Thunk::kNcclAllReduce, all_reduce),
      replica_count_(replica_count),
      element_count_(element_count),
      source_buffer_(source_buffer),
      destination_buffer_(destination_buffer) {}

}  // namespace gpu
}  // namespace xla