/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"

#include <chrono>  // NOLINT (required by TF interfaces)
#include <map>
#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/refcounting_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/mutex.h"

// This thunk's implementation is somewhat similar to our implementation of
// AllReduce using NCCL. One reason it's separate is that, because this doesn't
// use NCCL, it can work even without a CUDA compiler.

namespace xla {
namespace gpu {

namespace {

using tensorflow::BlockingCounter;

// This same function appears in nccl_all_reduce_thunk. I've copy/pasted it
// here primarily because I want the VLOGs to work.
template <typename DescFn>
void WaitAndLogIfStuck(tensorflow::BlockingCounter* counter,
                       const DescFn& desc_fn) {
  VLOG(3) << "Begin: " << desc_fn();
  const std::chrono::milliseconds timeout(5000);
  bool ok = counter->WaitFor(timeout);
  if (ok) {
    VLOG(3) << "Finished: " << desc_fn();
    return;
  }
  LOG(ERROR) << "This thread has been waiting for " << timeout.count()
             << "ms and may be stuck: " << desc_fn();
  counter->Wait();
  LOG(ERROR) << "Thread is unstuck! Warning above was a false positive. "
                "Perhaps the timeout is too short: "
             << desc_fn();
}
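
// For example (hypothetical caller):
//
//   BlockingCounter done(num_participants);
//   ...
//   WaitAndLogIfStuck(&done, [&] { return key.ToString(); });
//
// desc_fn is a callback so that the description string is only built when a
// message is actually logged.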

// Key for looking up a Rendezvous object in our global map.
//
// Morally, the key is just a RunId. num_participants is in this struct only
// because we use that information when constructing the Rendezvous.
struct RendezvousKey {
  RunId run_id;
  int num_participants;  // int, not int64, to match BlockingCounter's counter.

  string ToString() const {
    return absl::StrFormat("RendezvousKey{run_id=%s, num_participants=%d}",
                           run_id.ToString(), num_participants);
  }

  template <typename H>
  friend H AbslHashValue(H h, const RendezvousKey& k) {
    return H::combine(std::move(h), k.run_id);
  }
  friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
    return a.run_id == b.run_id;
  }
  friend bool operator!=(const RendezvousKey& a, const RendezvousKey& b) {
    return !(a == b);
  }
};
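
// Note that only run_id participates in RendezvousKey's hash and equality
// above; for example, RendezvousKey{run_id, 2} and RendezvousKey{run_id, 4}
// would map to the same Rendezvous. All threads in a run are expected to
// agree on num_participants.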

// Information about a thread that's participating in a collective-permute
// operation.
struct ParticipantData {
  int64 replica_id;
  se::Stream* stream;

  se::DeviceMemoryBase src;
  se::DeviceMemoryBase dest;

  // ReplicaIds to which we will copy the data in src.
  std::vector<int64> dest_replicas;
};

// The set of threads that want to do a collective permute together all pick
// the same Rendezvous object out of the global cache and call
// SubmitParticipant.
//
// The Rendezvous instance handles waiting for all threads to join and then
// doing the actual work of the collective permute.
//
// Rendezvous objects can only be used once.
class Rendezvous {
 public:
  explicit Rendezvous(const RendezvousKey& key) : key_(key) {}

  // Runs the collective permute on the given thread.
  //
  // If successful, returns a BlockingCounter initialized to the number of
  // participants, so that the caller can coordinate with the participants one
  // last time if it chooses. This is useful for coordinating destruction of
  // the Rendezvous.
  StatusOr<std::shared_ptr<BlockingCounter>> SubmitParticipant(
      ParticipantData participant);

 private:
  const RendezvousKey key_;
  BlockingCounter all_arrived_{key_.num_participants};

  // BlockingCounter returned by SubmitParticipant.
  std::shared_ptr<BlockingCounter> returned_blocking_counter_{
      std::make_shared<BlockingCounter>(key_.num_participants)};

  tensorflow::mutex mu_;
  bool initialized_ TF_GUARDED_BY(mu_) = false;

  // We use an std::map so that we can iterate over it below in a guaranteed
  // order. The order shouldn't actually matter, but why be nondeterministic
  // if we don't have to be?
  std::map<int64, ParticipantData> participants_ TF_GUARDED_BY(mu_);
};
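
// Illustrative lifecycle (mirroring ExecuteOnStream below): each
// participating thread does roughly
//
//   std::shared_ptr<Rendezvous> r =
//       GlobalRendezvousMap().GetOrCreateIfAbsent(key, factory);
//   TF_ASSIGN_OR_RETURN(auto final_sync, r->SubmitParticipant(participant));
//   r.reset();                     // Drop the map's reference, then...
//   final_sync->DecrementCount();  // ...coordinate destruction.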

void EnqueueCopy(se::DeviceMemoryBase src, se::Stream* src_stream,
                 se::DeviceMemoryBase dest, se::Stream* dest_stream) {
  CHECK_EQ(src.size(), dest.size());

  // If src_stream == dest_stream, we're doing a plain memcpy from one GPU
  // back to the same GPU. x->ThenWaitFor(x) is illegal, so this has to be a
  // special case.
  if (src_stream == dest_stream) {
    dest_stream->ThenMemcpyD2D(&dest, src, src.size());
    return;
  }

  // We (arbitrarily) choose to use the dest stream to perform the copy. This
  // means we need to make the dest stream wait until the src stream is ready
  // before it performs the copy, and then we need to make the src stream wait
  // until the dest stream has completed the copy.
  dest_stream->ThenWaitFor(src_stream).ThenMemcpyD2D(&dest, src, src.size());
  src_stream->ThenWaitFor(dest_stream);
}
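
// In the two-stream case, the resulting cross-stream ordering is:
//
//   dest_stream: ... -> [wait for src_stream] -> [D2D copy] -> ...
//   src_stream:  ... --------------------------> [wait for dest_stream] -> ...
//
// Neither stream runs past the copy until both streams are ready for it.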

StatusOr<std::shared_ptr<BlockingCounter>> Rendezvous::SubmitParticipant(
    ParticipantData participant) {
  bool primary;
  {
    tensorflow::mutex_lock lock(mu_);
    CHECK(participants_.emplace(participant.replica_id, participant).second);

    // The first thread to acquire the lock is designated as the primary.
    primary = !initialized_;

    if (primary) {
      initialized_ = true;
      returned_blocking_counter_ =
          std::make_shared<BlockingCounter>(key_.num_participants);
    }
  }

  // Wait for all participants to arrive. Even though our copies are async and
  // enqueued by just one thread, this is not optional! If we didn't wait for
  // everyone, then we wouldn't be able to enqueue the copies at the correct
  // point in their streams.
  all_arrived_.DecrementCount();
  WaitAndLogIfStuck(&all_arrived_, [&] {
    return absl::StrFormat(
        "participant for replica %d (stream %p, device %d) waiting for all "
        "other participants to arrive: %s",
        participant.replica_id, participant.stream,
        participant.stream->parent()->device_ordinal(), key_.ToString());
  });

  // Schedule the copies between the devices. This is much easier to reason
  // about if we schedule all of the copies from just one thread. The copies
  // are async anyway, so the number of host threads we use isn't important.
  if (primary) {
    tensorflow::mutex_lock lock(mu_);
    for (const auto& kv : participants_) {
      const ParticipantData& src_participant = kv.second;
      for (int64 dest_replica : src_participant.dest_replicas) {
        const ParticipantData& dest_participant =
            participants_.at(dest_replica);
        EnqueueCopy(src_participant.src, src_participant.stream,
                    dest_participant.dest, dest_participant.stream);
      }
    }
  }

  return returned_blocking_counter_;
}

// Global map of Rendezvous objects. A thread participating in a collective op
// looks up its Rendezvous in this map to find the other threads that it's
// participating with.
//
// Rendezvous objects are one-time use, so they're removed from this map once
// we're through with them.
RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
  static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>();
  return m;
}

}  // anonymous namespace

CollectivePermuteThunk::CollectivePermuteThunk(
    ThunkInfo thunk_info,
    std::vector<std::pair<int64, int64>> source_target_pairs,
    const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest)
    : Thunk(kCollectivePermute, thunk_info),
      source_target_pairs_(std::move(source_target_pairs)),
      src_(src),
      dest_(dest) {}

Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
  auto op_profiler =
      params.profiler->MakeScopedInstructionProfiler(profile_index());

  // Rendezvous with the threads for all other devices that are participating
  // in this CollectivePermute.
  RendezvousKey key{params.run_id, params.device_assn->replica_count()};
  auto rendezvous_factory = [](const RendezvousKey& key) {
    return absl::make_unique<Rendezvous>(key);
  };
  std::shared_ptr<Rendezvous> rendezvous =
      GlobalRendezvousMap().GetOrCreateIfAbsent(key, rendezvous_factory);

  TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                      params.GetGlobalDeviceId());
  TF_ASSIGN_OR_RETURN(int64 replica_id,
                      params.device_assn->ReplicaIdForDevice(global_device_id));

  // Figure out which replicas our data is copied to.
  std::vector<int64> dest_replicas;
  for (const auto& src_dest : source_target_pairs_) {
    if (src_dest.first == replica_id) {
      dest_replicas.push_back(src_dest.second);
    }
  }
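  // (For example, with source_target_pairs = {{0,1}, {1,2}}, replica 0
  // computes dest_replicas = {1}, replica 1 computes {2}, and replica 2
  // computes an empty list.)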

  auto src_addr = params.buffer_allocations->GetDeviceAddress(src_);
  auto dest_addr = params.buffer_allocations->GetDeviceAddress(dest_);
  ParticipantData participant{replica_id, params.stream, src_addr, dest_addr,
                              dest_replicas};
  TF_ASSIGN_OR_RETURN(std::shared_ptr<BlockingCounter> final_sync,
                      rendezvous->SubmitParticipant(participant));

  // If no replica writes into us (i.e. we aren't the target of any copies),
  // our contract is that we zero our output.
  if (absl::c_none_of(source_target_pairs_,
                      [&](std::pair<int64, int64> src_dest) {
                        return src_dest.second == replica_id;
                      })) {
    params.stream->ThenMemZero(&dest_addr, dest_addr.size());
  }
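  // (In the example above, no pair targets replica 0, so it is replica 0 that
  // zeroes its output here.)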

  // Drop our reference to the Rendezvous and wait for all other threads to do
  // the same. If we didn't do this, one of the threads could run past this
  // point, reenter ExecuteOnStream for another collective-permute, and attempt
  // to reuse the Rendezvous!
  //
  // An alternative way of accomplishing this goal would be to implement
  // RefcountingHashMap::erase() and call it during SubmitParticipant. But
  // erase() is deceptively complex to implement correctly.
  rendezvous.reset();
  final_sync->DecrementCount();
  WaitAndLogIfStuck(final_sync.get(), [&] {
    return absl::StrFormat(
        "participant for replica %d (stream %p, device ordinal %d) waiting "
        "for all threads to drop their reference to the rendezvous: %s",
        participant.replica_id, participant.stream,
        participant.stream->parent()->device_ordinal(), key.ToString());
  });
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla