/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"

#include "tensorflow/compiler/xla/util.h"

#if GOOGLE_CUDA
#include "absl/synchronization/blocking_counter.h"
#include "third_party/nccl/nccl.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#endif

namespace xla {
namespace gpu {

/* static */ bool NcclAllReduceThunk::NcclIsEnabled() {
#if GOOGLE_CUDA
  return true;
#else
  return false;
#endif
}

#if GOOGLE_CUDA
namespace {

// GPU-replica-driving host threads (i.e. the threads that call
// GpuExecutable::Execute) build up this structure to describe their
// participating replica, and then call
// GlobalRendezvousManager::SubmitParticipant.
struct ParticipantData {
  // Number of replicas participating in the AllReduce.
  int64 replica_count;

  int64 element_count;
  int64 device_ordinal;
  int64 generation_counter;

  // TODO(b/125951860): We should vet that we're buffer allocating such that
  // source_buffer == destination_buffer if that avoids a NCCL copy (will
  // depend on how well the NCCL in-place implementation performs vs the
  // out-of-place implementation).
  se::DeviceMemoryBase source_data;
  se::DeviceMemoryBase destination_data;
  se::Stream* stream;

  NcclAllReduceThunk* originator;

  string ToString() const {
    return absl::StrFormat(
        "ParticipantData{replica_count=%d, element_count=%d, "
        "device_ordinal=%d, generation_counter=%d, stream=%p, originator=%p}",
        replica_count, element_count, device_ordinal, generation_counter,
        stream, originator);
  }
};

// Class that gets instantiated as a singleton in GetGlobalRendezvous() to
// coordinate participating threads in performing an AllReduce operation.
//
// This manager is responsible for establishing communication channels and
// ultimately enqueueing the NCCL library operation onto the participating
// streams.
class GlobalRendezvousManager {
 public:
  // The GpuExecutable-executing threads call this in order to a) establish the
  // all-reduce rendezvous and b) enqueue the AllReduce operation on the caller
  // thread's associated stream (given in "participant").
  //
  // Implementation note: since the rendezvous we're creating here is global,
  // we try to be paranoid about checking that the *correct* one is happening.
  // In an ideal world we'd have some StreamExecutor se::Platform level
  // construct that we could use for cross-device networking primitives (e.g.
  // via a NetworkSupport interface) that could be shared between TensorFlow
  // and XLA, but this is a reasonable stopgap measure to get multi-GPU-replica
  // up and running properly for single-host, single-concurrent-XLA-module
  // usage.
  Status SubmitParticipant(ParticipantData participant);

  // Returns the current generation number of AllReduce operations.
  // (Currently one AllReduce operation occurs per generation.)
  int64 GetCurrentGeneration() {
    tensorflow::mutex_lock lock(mutex_);
    return current_generation_;
  }

 private:
  // Called by the primary thread to set up the communication links.
  //
  // TODO(b/125951860): This performs lots of (presumably) unnecessary
  // host-side synchronization so that we can be paranoid about semantics in
  // the earliest implementation. In the limit we should only need to
  // synchronize host replica threads when the "number of replicas" or
  // "participating device ordinals" change, to set up a new NCCL
  // "communication" context, at which point we can enqueue onto device streams
  // without host synchronization in our code -- this will likely be helpful
  // for "lots of little AllReduce" cases.
  Status InitializeCommunicationChannels() EXCLUSIVE_LOCKS_REQUIRED(mutex_);

  // Called when all necessary participants are present; the functionality
  // that's executed by all participating threads lives in here.
  Status DoAllReduce(ParticipantData data, ncclComm_t comm);

  // Puts all state back into a "reset" state for the next generation of
  // AllReduce requests.
  void DeinitializeGeneration() EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
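    // Tear down the NCCL communicators for this generation; the next
    // rendezvous recreates them from scratch via
    // InitializeCommunicationChannels().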
    for (ncclComm_t& comm : comms_) {
      ncclCommDestroy(comm);
    }
    comms_.clear();
    participants_.clear();
    current_generation_++;
    initialized_ = false;
    done_ = absl::nullopt;
  }

  tensorflow::mutex mutex_;
  tensorflow::condition_variable all_participants_present_;
  tensorflow::condition_variable deinitialized_;

  // Communication handles that correspond to the participants below.
  std::vector<ncclComm_t> comms_ GUARDED_BY(mutex_);

  Status initialize_status_ GUARDED_BY(mutex_);
  std::vector<ParticipantData> participants_ GUARDED_BY(mutex_);
  int64 current_generation_ GUARDED_BY(mutex_) = 0;
  bool initialized_ GUARDED_BY(mutex_) = false;

  // The participating threads wait for this to count down in order to know we
  // can begin the teardown process.
  absl::optional<tensorflow::BlockingCounter> done_;
};

Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
  auto all_participants_present = [this, &participant]()
                                      EXCLUSIVE_LOCKS_REQUIRED(mutex_) -> bool {
    return participants_.size() >= participant.replica_count;
  };

  // We remember the participant index at which we are inserted and use that
  // same index for referring to auxiliary metadata (e.g. the ncclComm_t handle
  // index) below.
  int64 index;

  {
    tensorflow::mutex_lock lock(mutex_);

    // Spot check for consistent replica counts among submitting threads.
    if (!participants_.empty() &&
        (participants_.back().replica_count != participant.replica_count ||
         participants_.back().originator != participant.originator)) {
      return InvalidArgument(
          "Running two XLA modules with AllReduces in parallel is not "
          "supported. It is possible this is due to a bug where we try to "
          "run two different AllReduces from the same module at once. "
          "(Attempted a rendezvous with a different replica count from other "
          "participants; existing: %s; submitted: %s)",
          participants_.back().ToString(), participant.ToString());
    }
    index = participants_.size();
    participants_.push_back(participant);

    if (all_participants_present()) {
      all_participants_present_.notify_all();
    }
  }

  // We pull into our thread a) the communication handle and b) whether we're
  // the "primary" thread for this rendezvous -- the "primary" thread has some
  // additional responsibilities for setup/teardown.
  ncclComm_t comm;
  bool primary;

  {
    tensorflow::mutex_lock lock(mutex_);
    while (!all_participants_present()) {
      // Once all the participants have arrived, all participating threads will
      // cross this barrier, though only (the first) one will be the "primary".
      all_participants_present_.wait(lock);
    }

    // Somebody will be the first -- that thread has some additional
    // responsibilities.
    primary = !initialized_;

    CHECK_EQ(participant.generation_counter, current_generation_);

    // The primary thread marks this generation as initialized and sets up the
    // state the other threads need (communicators and the completion counter)
    // before they can proceed.
    if (primary) {
      VLOG(3) << "Primary initializing accounting data.";
      initialized_ = true;
      done_.emplace(participant.replica_count);
      initialize_status_ = InitializeCommunicationChannels();
      VLOG(3) << "Done initializing communication channels; status: "
              << initialize_status_;
      if (!initialize_status_.ok()) {
        DeinitializeGeneration();
      }
    }

    if (!initialize_status_.ok()) {
      // TODO(b/125951860): If this fails once, it will fail forever.
      return initialize_status_;
    }

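    // comms_ was filled in by the primary thread in the same order in which
    // participants registered themselves, so the insertion index recorded
    // above selects this thread's own communicator.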
    comm = comms_[index];

    // Drop the lock at the end of scope so other participants may enter.
  }

  VLOG(3) << "Performing all reduce from device ordinal: "
          << participant.device_ordinal;

  Status all_reduce_status = DoAllReduce(participant, comm);

  VLOG(3) << "Waiting for all participants to complete enqueue.";

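  // Signal that this thread has finished enqueueing its NCCL work; the primary
  // thread waits on this counter before tearing down the generation state.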
  done_->DecrementCount();

  if (primary) {
    // The primary thread clears out the AllReduce state when everybody is done
    // so that we have a clean slate for any subsequent AllReduce request (e.g.
    // the number of replicas may change in the next request).
    //
    // Note surrounding TODOs for only reinitializing this when the replica
    // count / participants actually change -- lots of "playing it safe"
    // happening in this first cut.
    done_->Wait();
    VLOG(3) << "All participants completed enqueue.";
    VLOG(3) << "Primary thread clearing.";
    tensorflow::mutex_lock lock(mutex_);
    DeinitializeGeneration();
    VLOG(3) << "Generation is now: " << current_generation_;
    deinitialized_.notify_all();
  } else {
    VLOG(3) << "Waiting to deinitialize.";
    tensorflow::mutex_lock lock(mutex_);
    while (initialized_) {
      deinitialized_.wait(lock);
    }
  }

  VLOG(3) << "Returning status: " << all_reduce_status;
  return all_reduce_status;
}

Status GlobalRendezvousManager::InitializeCommunicationChannels() {
  std::vector<int> ordinals;
  for (ParticipantData& data : participants_) {
    ordinals.push_back(data.device_ordinal);
  }
  comms_.resize(ordinals.size());
  VLOG(3) << "Participants: " << participants_.size()
          << "; initializing comms.";
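  // ncclCommInitAll creates one communicator per listed device ordinal in a
  // single host call, which is why every participant must have registered
  // before the primary thread initializes the channels.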
  ncclResult_t result = ncclCommInitAll(comms_.data(), comms_.size(),
                                        /*devlist=*/ordinals.data());
  if (result != ncclSuccess) {
    comms_.clear();
    return InternalError(
        "Failed to initialize NCCL communication channels for %d participants: "
        "%s",
        participants_.size(), ncclGetErrorString(result));
  }
  return Status::OK();
}

Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant,
                                            ncclComm_t comm) {
  se::StreamExecutor* executor = participant.stream->parent();
  se::cuda::ScopedActivateExecutorContext scoped_context(executor);
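  // Pull the underlying cudaStream_t out of the StreamExecutor stream so that
  // the NCCL collective is enqueued onto the same stream the rest of the
  // GpuExecutable runs on.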
  cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
      participant.stream->implementation()->GpuStreamMemberHack());
  VLOG(3) << "Using stream pointer: " << cu_stream
          << " on device: " << participant.device_ordinal;
  void* send_buffer = participant.source_data.opaque();
  void* recv_buffer = participant.destination_data.opaque();
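  // Note: this first implementation hard-codes the reduction to a float32
  // summation (ncclFloat / ncclSum); other element types and reduction kinds
  // are not handled here.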
  ncclResult_t result = ncclAllReduce(send_buffer, recv_buffer,
                                      /*count=*/participant.element_count,
                                      /*datatype=*/ncclFloat,
                                      /*op=*/ncclSum,
                                      /*comm=*/comm,
                                      /*stream=*/*cu_stream);
  TF_RET_CHECK(ncclSuccess == result)
      << "Failed to perform all-reduce: " << ncclGetErrorString(result);

  VLOG(3) << "Done performing all reduce for ordinal: "
          << participant.device_ordinal;

  return Status::OK();
}

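// Returns the process-wide rendezvous manager. The instance is intentionally
// leaked (never deleted) so it remains valid for the lifetime of the process.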
static GlobalRendezvousManager* GetGlobalRendezvous() {
  static auto* manager = new GlobalRendezvousManager;
  return manager;
}

}  // namespace

Status NcclAllReduceThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  auto* global_rendezvous = GetGlobalRendezvous();

  ParticipantData participant;
  participant.replica_count = replica_count_;
  participant.element_count = element_count_;
  participant.device_ordinal = stream->parent()->device_ordinal();
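  // The generation counter recorded here is re-checked inside
  // SubmitParticipant to ensure all participants joined the same rendezvous
  // generation.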
  participant.generation_counter = global_rendezvous->GetCurrentGeneration();
  participant.source_data = buffer_allocations.GetDeviceAddress(source_buffer_);
  participant.destination_data =
      buffer_allocations.GetDeviceAddress(destination_buffer_);
  participant.stream = stream;
  participant.originator = this;

  return GetGlobalRendezvous()->SubmitParticipant(std::move(participant));
}
#else

Status NcclAllReduceThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  return Unimplemented(
      "NCCL support is not available: this binary was not built with a CUDA "
      "compiler, which is necessary to build the NCCL source library.");
}

#endif  // GOOGLE_CUDA

NcclAllReduceThunk::NcclAllReduceThunk(
    int64 replica_count, int64 element_count,
    const BufferAllocation::Slice& source_buffer,
    const BufferAllocation::Slice& destination_buffer,
    const HloInstruction* all_reduce)
    : Thunk(Thunk::kNcclAllReduce, all_reduce),
      replica_count_(replica_count),
      element_count_(element_count),
      source_buffer_(source_buffer),
      destination_buffer_(destination_buffer) {}

}  // namespace gpu
}  // namespace xla